Deprecation of Triton-C and Replacement by decorated Python functions (#86)
This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes. See documentations for more information on the new API
This commit is contained in:
committed by
Philippe Tillet
parent
1fdb465b71
commit
39f4730305
@@ -1,30 +1,31 @@
|
||||
#ifndef _TRITON_CODEGEN_PASS_H_
|
||||
#define _TRITON_CODEGEN_PASS_H_
|
||||
|
||||
#include <list>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace ir{
|
||||
class module;
|
||||
}
|
||||
namespace driver{
|
||||
class device;
|
||||
class module;
|
||||
class kernel;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
class pass {
|
||||
public:
|
||||
virtual void run(ir::module& m);
|
||||
};
|
||||
// TODO:
|
||||
// There should be a proper pass manager there!
|
||||
void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps,
|
||||
driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
|
||||
|
||||
|
||||
class pass_manager {
|
||||
public:
|
||||
void add(pass* p);
|
||||
void run(ir::module& m);
|
||||
|
||||
private:
|
||||
std::list<pass*> passes;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@@ -119,7 +119,7 @@ public:
|
||||
void visit_exp_inst(ir::exp_inst*);
|
||||
void visit_log_inst(ir::log_inst*);
|
||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||
void visit_get_num_program_inst(ir::get_num_program_inst*);
|
||||
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
|
||||
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
|
||||
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
|
||||
void visit_atomic_add_inst(ir::atomic_add_inst*);
|
||||
|
@@ -59,7 +59,7 @@ public:
|
||||
|
||||
// CUDA
|
||||
class cu_module: public module {
|
||||
std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);
|
||||
std::string compile_llvm_module(llvm::Module* module, driver::device* device);
|
||||
void init_from_ptx(const std::string& ptx);
|
||||
|
||||
public:
|
||||
|
@@ -1,4 +1,4 @@
|
||||
#pragma once
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_BUILDER_H_
|
||||
#define _TRITON_IR_BUILDER_H_
|
||||
@@ -27,6 +27,8 @@ class builder{
|
||||
public:
|
||||
// Constructor
|
||||
builder(context &ctx);
|
||||
// Getters
|
||||
const context& get_context() { return ctx_; }
|
||||
// Setters
|
||||
void set_insert_point(iterator instr);
|
||||
void set_insert_point(instruction* i);
|
||||
@@ -38,6 +40,9 @@ public:
|
||||
value *get_int1(bool val);
|
||||
value *get_int32(int32_t val);
|
||||
value *get_int64(int64_t val);
|
||||
value *get_float16(float val);
|
||||
value *get_float32(float val);
|
||||
value *get_range(int32_t lo, int32_t hi);
|
||||
// Types
|
||||
type *get_void_ty();
|
||||
type *get_int1_ty();
|
||||
@@ -50,11 +55,10 @@ public:
|
||||
type *get_double_ty();
|
||||
// Insert
|
||||
template<typename InstTy>
|
||||
InstTy* insert(InstTy *inst, const std::string &name = ""){
|
||||
InstTy* insert(InstTy *inst){
|
||||
assert(block_);
|
||||
block_->get_inst_list().insert(insert_point_, inst);
|
||||
inst->set_parent(block_);
|
||||
inst->set_name(name);
|
||||
// for(ir::value* op: inst->ops())
|
||||
// op->add_use(inst);
|
||||
return inst;
|
||||
@@ -64,91 +68,87 @@ public:
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
// Cast instructions
|
||||
value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = "");
|
||||
value* create_ptr_to_int(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_fp_to_ui(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_fp_ext(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_fp_trunc(value *src, type *dst_ty, const std::string &name = "");
|
||||
value* create_int_cast(value *src, type *dst_ty, bool is_signed, const std::string &name = "");
|
||||
value *create_downcast(value *arg, const std::string &name = "");
|
||||
value *create_cast(cast_op_t op, value *v, type *dst_ty);
|
||||
value* create_ptr_to_int(value *src, type *dst_ty);
|
||||
value* create_si_to_fp(value *src, type *dst_ty);
|
||||
value* create_ui_to_fp(value *src, type *dst_ty);
|
||||
value* create_fp_to_si(value *src, type *dst_ty);
|
||||
value* create_fp_to_ui(value *src, type *dst_ty);
|
||||
value* create_fp_ext(value *src, type *dst_ty);
|
||||
value* create_fp_trunc(value *src, type *dst_ty);
|
||||
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
|
||||
value *create_downcast(value *arg);
|
||||
// Phi instruction
|
||||
phi_node* create_phi(type *ty, unsigned num_reserved, const std::string &name = "");
|
||||
phi_node* create_phi(type *ty, unsigned num_reserved);
|
||||
// Binary instructions
|
||||
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw);
|
||||
value *create_fmul(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fdiv(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_frem(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fadd(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fsub(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_mul(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sdiv(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_udiv(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_srem(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_urem(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_add(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sub(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_shl(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_lshr(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_ashr(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw);
|
||||
value *create_fmul(value *lhs, value *rhs);
|
||||
value *create_fdiv(value *lhs, value *rhs);
|
||||
value *create_frem(value *lhs, value *rhs);
|
||||
value *create_fadd(value *lhs, value *rhs);
|
||||
value *create_fsub(value *lhs, value *rhs);
|
||||
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sdiv(value *lhs, value *rhs);
|
||||
value *create_udiv(value *lhs, value *rhs);
|
||||
value *create_srem(value *lhs, value *rhs);
|
||||
value *create_urem(value *lhs, value *rhs);
|
||||
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
|
||||
// GEP
|
||||
value *create_gep(value *ptr, const std::vector<value*>& idx_list, const std::string &name = "");
|
||||
value *create_gep(value *ptr, const std::vector<value*>& idx_list);
|
||||
// Comparison (int)
|
||||
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpSLE(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpSLT(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpSGE(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpSGT(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpULE(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpULT(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpUGE(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpUGT(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpEQ(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmpNE(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs);
|
||||
value *create_icmpSLE(value *lhs, value *rhs);
|
||||
value *create_icmpSLT(value *lhs, value *rhs);
|
||||
value *create_icmpSGE(value *lhs, value *rhs);
|
||||
value *create_icmpSGT(value *lhs, value *rhs);
|
||||
value *create_icmpULE(value *lhs, value *rhs);
|
||||
value *create_icmpULT(value *lhs, value *rhs);
|
||||
value *create_icmpUGE(value *lhs, value *rhs);
|
||||
value *create_icmpUGT(value *lhs, value *rhs);
|
||||
value *create_icmpEQ(value *lhs, value *rhs);
|
||||
value *create_icmpNE(value *lhs, value *rhs);
|
||||
// Comparison (float)
|
||||
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmpOLT(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmpOGT(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmpOLE(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmpOGE(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmpOEQ(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmpONE(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs);
|
||||
value *create_fcmpOLT(value *lhs, value *rhs);
|
||||
value *create_fcmpOGT(value *lhs, value *rhs);
|
||||
value *create_fcmpOLE(value *lhs, value *rhs);
|
||||
value *create_fcmpOGE(value *lhs, value *rhs);
|
||||
value *create_fcmpOEQ(value *lhs, value *rhs);
|
||||
value *create_fcmpONE(value *lhs, value *rhs);
|
||||
// Logical
|
||||
value *create_and(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_xor(value *lhs, value *rhs, const std::string &name = "");
|
||||
value *create_or(value *lhs, value *rhs, const std::string &name = "");
|
||||
// Unary
|
||||
// value *create_fneg(value *arg, const std::string &name = "");
|
||||
// value *create_neg(value *arg, const std::string &name = "");
|
||||
// value *create_not(value *arg, const std::string &name = "");
|
||||
value *create_and(value *lhs, value *rhs);
|
||||
value *create_xor(value *lhs, value *rhs);
|
||||
value *create_or(value *lhs, value *rhs);
|
||||
// Input/Output
|
||||
value *create_load(value *arg, const std::string &name = "");
|
||||
value *create_store(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value, const std::string &name = "");
|
||||
value *create_masked_store(value *ptr, value *val, value *mask, const std::string &name = "");
|
||||
// Tile instruction
|
||||
value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
value *create_load(value *arg);
|
||||
value *create_store(value *ptr, value *val);
|
||||
value *create_masked_load(value *arg, value *mask, value *false_value);
|
||||
value *create_masked_store(value *ptr, value *val, value *mask);
|
||||
// Block instruction
|
||||
value *create_splat(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
|
||||
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
|
||||
// Built-in instruction
|
||||
value *create_get_program_id(unsigned axis, const std::string &name = "");
|
||||
value *create_get_num_program(unsigned axis, const std::string &name = "");
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
|
||||
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = "");
|
||||
value *create_exp(value* arg, const std::string &name = "");
|
||||
value *create_log(value* arg, const std::string &name = "");
|
||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
|
||||
value *create_sqrt(value *A, const std::string &name = "");
|
||||
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = "");
|
||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||
value *create_get_program_id(unsigned axis);
|
||||
value *create_get_num_programs(unsigned axis);
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val);
|
||||
value *create_atomic_exch(value *ptr, value *val);
|
||||
value *create_atomic_add(value *ptr, value *val, value *msk);
|
||||
value *create_exp(value* arg);
|
||||
value *create_log(value* arg);
|
||||
value *create_dot(value *A, value *B, value *C);
|
||||
value *create_trans(value *A, const std::vector<int> &perm = {});
|
||||
value *create_sqrt(value *A);
|
||||
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
|
||||
value *create_select(value *pred, value *if_value, value *else_value);
|
||||
// Intrinsics
|
||||
value *create_copy_to_shared(value *arg, const std::string &name = "");
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
|
||||
value *create_copy_from_shared(value *arg, const std::string &name = "");
|
||||
value *create_copy_to_shared(value *arg);
|
||||
value *create_masked_load_async(value *arg, value *mask, value *false_value);
|
||||
value *create_copy_from_shared(value *arg);
|
||||
value *create_barrier(const std::string &name = "");
|
||||
value *create_async_wait(int N);
|
||||
|
||||
@@ -158,6 +158,7 @@ private:
|
||||
iterator insert_point_;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -9,6 +9,7 @@
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class builder;
|
||||
class type;
|
||||
class context_impl;
|
||||
|
||||
@@ -16,8 +17,11 @@ class context_impl;
|
||||
class context {
|
||||
public:
|
||||
context();
|
||||
context(const context&) = delete;
|
||||
context& operator=(const context&) = delete;
|
||||
|
||||
public:
|
||||
ir::builder* builder = nullptr;
|
||||
std::shared_ptr<context_impl> p_impl;
|
||||
};
|
||||
|
||||
|
@@ -28,13 +28,16 @@ public:
|
||||
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
|
||||
// Pointer types
|
||||
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
|
||||
std::map<std::pair<type*, type::tile_shapes_t>, tile_type*> tile_tys;
|
||||
// Block types
|
||||
std::map<std::pair<type*, type::block_shapes_t>, block_type*> block_tys;
|
||||
|
||||
// Int constants
|
||||
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
|
||||
// Float constants
|
||||
std::map<std::pair<type*, double>, constant_fp*> fp_constants_;
|
||||
// undef values
|
||||
std::map<type*, undef_value*> uv_constants_;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
97
include/triton/ir/dispatch.h
Normal file
97
include/triton/ir/dispatch.h
Normal file
@@ -0,0 +1,97 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_IR_DISPATCH_H_
|
||||
#define _TRITON_IR_DISPATCH_H_
|
||||
|
||||
#include "triton/ir/builder.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
|
||||
/*----------------------------------------------
|
||||
higher level functions that follow the likely
|
||||
semantics of most expected frontends
|
||||
----------------------------------------------*/
|
||||
|
||||
struct semantic_error: public std::runtime_error {
|
||||
semantic_error(const std::string& msg):
|
||||
std::runtime_error(msg) { }
|
||||
};
|
||||
|
||||
struct dispatch{
|
||||
typedef ir::type::block_shapes_t shape_t;
|
||||
|
||||
|
||||
// programming model
|
||||
static ir::value *program_id(int axis, ir::builder *builder);
|
||||
static ir::value *num_programs(int axis, ir::builder *builder);
|
||||
|
||||
// binary operators
|
||||
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
|
||||
// unary operators
|
||||
static ir::value *plus(ir::value *input, ir::builder *builder);
|
||||
static ir::value *minus(ir::value *input, ir::builder *builder);
|
||||
static ir::value *invert(ir::value *input, ir::builder *builder);
|
||||
|
||||
// comparison operators
|
||||
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
|
||||
|
||||
// block creation
|
||||
static ir::value* arange(int start, int end, ir::builder *builder);
|
||||
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
|
||||
|
||||
|
||||
// casting ops
|
||||
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
|
||||
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
|
||||
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
|
||||
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
|
||||
|
||||
// memory operators
|
||||
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder);
|
||||
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
|
||||
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
|
||||
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
|
||||
|
||||
// linear algebra
|
||||
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
|
||||
|
||||
// indexing
|
||||
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
|
||||
|
||||
// reduction
|
||||
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
|
||||
|
||||
// math
|
||||
static ir::value *exp(ir::value *x, ir::builder *builder);
|
||||
static ir::value *log(ir::value *x, ir::builder *builder);
|
||||
static ir::value *sqrt(ir::value *x, ir::builder *builder);
|
||||
|
||||
// internal (debug/optimization)
|
||||
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
|
||||
static ir::value *debug_barrier(ir::builder *builder);
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -35,7 +35,7 @@ private:
|
||||
|
||||
/* Attribute */
|
||||
enum attribute_kind_t {
|
||||
readonly,
|
||||
readonly = 0,
|
||||
writeonly,
|
||||
noalias,
|
||||
aligned,
|
||||
@@ -71,7 +71,7 @@ public:
|
||||
case writeonly: return ".writeonly";
|
||||
case noalias: return ".noalias";
|
||||
case aligned: return ".aligned(" + std::to_string(value_) + ")";
|
||||
case multiple_of: return ".readonly";
|
||||
case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
|
||||
case retune: return ".retunr";
|
||||
default: break;
|
||||
}
|
||||
@@ -102,7 +102,7 @@ private:
|
||||
|
||||
public:
|
||||
// accessors
|
||||
const args_t &args() { return args_; }
|
||||
const args_t &args() const { return args_; }
|
||||
function_type* get_fn_type() { return fn_ty_; }
|
||||
|
||||
// factory methods
|
||||
|
@@ -514,7 +514,7 @@ public:
|
||||
|
||||
class retile_inst: public unary_inst {
|
||||
protected:
|
||||
retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
|
||||
retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next);
|
||||
};
|
||||
|
||||
// reshape
|
||||
@@ -525,7 +525,7 @@ private:
|
||||
std::string repr_impl() const { return "reshape"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(reshape_inst)
|
||||
_TRITON_DEFINE_ACCEPT(reshape_inst)
|
||||
@@ -539,7 +539,7 @@ private:
|
||||
std::string repr_impl() const { return "splat"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(splat_inst)
|
||||
_TRITON_DEFINE_ACCEPT(splat_inst)
|
||||
@@ -553,7 +553,7 @@ private:
|
||||
std::string repr_impl() const { return "broadcast"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(broadcast_inst)
|
||||
_TRITON_DEFINE_ACCEPT(broadcast_inst)
|
||||
@@ -597,16 +597,16 @@ private:
|
||||
unsigned axis_;
|
||||
};
|
||||
|
||||
class get_num_program_inst: public builtin_inst {
|
||||
class get_num_programs_inst: public builtin_inst {
|
||||
private:
|
||||
get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_num_program(" + std::to_string(axis_) + ")"; }
|
||||
get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_num_program_inst)
|
||||
_TRITON_DEFINE_ACCEPT(get_num_program_inst)
|
||||
_TRITON_DEFINE_CLONE(get_num_programs_inst)
|
||||
_TRITON_DEFINE_ACCEPT(get_num_programs_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
|
@@ -36,6 +36,11 @@ class alloc_const;
|
||||
|
||||
/* Module */
|
||||
struct scope {
|
||||
public:
|
||||
const std::map<std::string, ir::value*>& get_values() { return values; }
|
||||
void set_type(const std::string& name, ir::type* ty) { types[name] = ty; }
|
||||
ir::type* get_type(const std::string& name) { return types.at(name); }
|
||||
private:
|
||||
std::map<std::string, ir::type*> types;
|
||||
std::map<std::string, ir::value*> values;
|
||||
};
|
||||
@@ -61,8 +66,7 @@ private:
|
||||
void push_function(function *fn) { functions_.push_back(fn); }
|
||||
|
||||
public:
|
||||
module(const std::string &name);
|
||||
context& get_context();
|
||||
module(const std::string &name, builder& builder);
|
||||
builder& get_builder();
|
||||
// Setters
|
||||
void set_value(const std::string& name, basic_block* block, value *x);
|
||||
@@ -95,8 +99,7 @@ public:
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
context context_;
|
||||
builder builder_;
|
||||
builder& builder_;
|
||||
std::map<val_key_t, value*> values_;
|
||||
std::map<val_key_t, type*> types_;
|
||||
std::set<std::string> const_;
|
||||
|
@@ -18,7 +18,7 @@ class constant_int;
|
||||
/* Type */
|
||||
class type {
|
||||
public:
|
||||
typedef std::vector<unsigned> tile_shapes_t;
|
||||
typedef std::vector<unsigned> block_shapes_t;
|
||||
|
||||
protected:
|
||||
typedef std::vector<type*> contained_tys_vec_t;
|
||||
@@ -43,7 +43,7 @@ public:
|
||||
FunctionTyID, ///< 11: Functions
|
||||
PointerTyID, ///< 12: Pointers
|
||||
StructTyID, ///< 13: Struct
|
||||
TileTyID, ///< 14: Tile
|
||||
BlockTyID, ///< 14: Tile
|
||||
};
|
||||
|
||||
public:
|
||||
@@ -62,7 +62,7 @@ public:
|
||||
unsigned get_tile_bitwidth() const;
|
||||
unsigned get_primitive_size_in_bits() const;
|
||||
type *get_scalar_ty() const;
|
||||
const tile_shapes_t& get_tile_shapes() const;
|
||||
block_shapes_t get_block_shapes() const;
|
||||
const size_t get_tile_rank() const;
|
||||
const size_t get_tile_ranks1() const;
|
||||
unsigned get_tile_num_elements() const;
|
||||
@@ -83,7 +83,7 @@ public:
|
||||
get_integer_bitwidth() == bitwidth;}
|
||||
bool is_bool_ty() const { return is_integer_ty(1); }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_tile_ty() const { return id_ == TileTyID; }
|
||||
bool is_block_ty() const { return id_ == BlockTyID; }
|
||||
|
||||
// Composite predicates
|
||||
bool is_int_or_tileint_ty();
|
||||
@@ -110,7 +110,7 @@ public:
|
||||
// repr
|
||||
std::string tile_repr() const {
|
||||
std::string res = get_tile_element_ty()->repr();
|
||||
auto shapes = get_tile_shapes();
|
||||
auto shapes = get_block_shapes();
|
||||
res += "<";
|
||||
for(size_t i = 0; i < shapes.size(); i++){
|
||||
if(i > 0)
|
||||
@@ -137,7 +137,7 @@ public:
|
||||
case FunctionTyID: return "fn";
|
||||
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
|
||||
case StructTyID: return "struct";
|
||||
case TileTyID: return tile_repr();
|
||||
case BlockTyID: return tile_repr();
|
||||
default: break;
|
||||
}
|
||||
assert(false);
|
||||
@@ -180,23 +180,23 @@ public:
|
||||
type* get_type_at_index(value *idx) const;
|
||||
};
|
||||
|
||||
class tile_type: public composite_type {
|
||||
class block_type: public composite_type {
|
||||
private:
|
||||
tile_type(type *ty, const tile_shapes_t &shapes);
|
||||
block_type(type *ty, const block_shapes_t &shapes);
|
||||
static bool is_valid_elt_ty(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
const tile_shapes_t& get_shapes() const { return shapes_; }
|
||||
const block_shapes_t& get_shapes() const { return shapes_; }
|
||||
unsigned get_num_elements() const;
|
||||
unsigned get_bitwidth() const;
|
||||
|
||||
// factory methods
|
||||
static tile_type* get(type *ty, const tile_shapes_t &shapes);
|
||||
static tile_type* get_same_shapes(type *ty, type *ref);
|
||||
static block_type* get(type *ty, const block_shapes_t &shapes);
|
||||
static block_type* get_same_shapes(type *ty, type *ref);
|
||||
|
||||
private:
|
||||
tile_shapes_t shapes_;
|
||||
block_shapes_t shapes_;
|
||||
};
|
||||
|
||||
class pointer_type: public type {
|
||||
|
@@ -52,7 +52,7 @@ class exp_inst;
|
||||
class log_inst;
|
||||
|
||||
class get_program_id_inst;
|
||||
class get_num_program_inst;
|
||||
class get_num_programs_inst;
|
||||
class atomic_cas_inst;
|
||||
class atomic_exch_inst;
|
||||
class atomic_add_inst;
|
||||
@@ -128,7 +128,7 @@ public:
|
||||
virtual void visit_downcast_inst(downcast_inst*) = 0;
|
||||
|
||||
virtual void visit_get_program_id_inst(get_program_id_inst*) = 0;
|
||||
virtual void visit_get_num_program_inst(get_num_program_inst*) = 0;
|
||||
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
|
||||
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
|
||||
virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
|
||||
virtual void visit_atomic_add_inst(atomic_add_inst*) = 0;
|
||||
|
@@ -1,823 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_AST_H_
|
||||
#define _WGTCC_AST_H_
|
||||
|
||||
#include "error.h"
|
||||
#include "token.h"
|
||||
#include "type.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
|
||||
class Visitor;
|
||||
template<typename T> class Evaluator;
|
||||
class AddrEvaluator;
|
||||
class Generator;
|
||||
|
||||
class Scope;
|
||||
class Parser;
|
||||
class ASTNode;
|
||||
class Token;
|
||||
class TokenSequence;
|
||||
|
||||
// Expressions
|
||||
class Expr;
|
||||
class BinaryOp;
|
||||
class UnaryOp;
|
||||
class ConditionalOp;
|
||||
class FuncCall;
|
||||
class TempVar;
|
||||
class Constant;
|
||||
|
||||
class Identifier;
|
||||
class Object;
|
||||
struct Initializer;
|
||||
class Declaration;
|
||||
class Enumerator;
|
||||
|
||||
// Statements
|
||||
class Stmt;
|
||||
class IfStmt;
|
||||
class ForStmt;
|
||||
class JumpStmt;
|
||||
class LabelStmt;
|
||||
class EmptyStmt;
|
||||
class CompoundStmt;
|
||||
class FuncDef;
|
||||
class TranslationUnit;
|
||||
|
||||
|
||||
/*
|
||||
* AST Node
|
||||
*/
|
||||
|
||||
class ASTNode {
|
||||
public:
|
||||
struct Attr{
|
||||
|
||||
enum KindT{
|
||||
MULTIPLEOF,
|
||||
ALIGNED,
|
||||
NOALIAS,
|
||||
READONLY,
|
||||
WRITEONLY,
|
||||
RETUNE,
|
||||
};
|
||||
|
||||
KindT kind;
|
||||
std::vector<Expr*> vals;
|
||||
};
|
||||
using AttrList = std::vector<Attr>;
|
||||
|
||||
public:
|
||||
virtual ~ASTNode() {}
|
||||
virtual void Accept(Visitor* v) = 0;
|
||||
|
||||
protected:
|
||||
ASTNode() {}
|
||||
|
||||
MemPool* pool_ {nullptr};
|
||||
};
|
||||
|
||||
using ExtDecl = ASTNode;
|
||||
|
||||
|
||||
/*
|
||||
* Statements
|
||||
*/
|
||||
|
||||
class Stmt : public ASTNode {
|
||||
public:
|
||||
virtual ~Stmt() {}
|
||||
|
||||
protected:
|
||||
Stmt() {}
|
||||
};
|
||||
|
||||
|
||||
class EmptyStmt : public Stmt {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static EmptyStmt* New();
|
||||
virtual ~EmptyStmt() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
|
||||
protected:
|
||||
EmptyStmt() {}
|
||||
};
|
||||
|
||||
|
||||
class LabelStmt : public Stmt {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static LabelStmt* New();
|
||||
~LabelStmt() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
std::string Repr() const { return ".L" + std::to_string(tag_); }
|
||||
|
||||
protected:
|
||||
LabelStmt(): tag_(GenTag()) {}
|
||||
|
||||
private:
|
||||
static int GenTag() {
|
||||
static int tag = 0;
|
||||
return ++tag;
|
||||
}
|
||||
|
||||
int tag_; // 使用整型的tag值,而不直接用字符串
|
||||
};
|
||||
|
||||
|
||||
class IfStmt : public Stmt {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
public:
|
||||
static IfStmt* New(Expr* cond, Stmt* then, Stmt* els=nullptr);
|
||||
virtual ~IfStmt() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
|
||||
protected:
|
||||
IfStmt(Expr* cond, Stmt* then, Stmt* els = nullptr)
|
||||
: cond_(cond), then_(then), else_(els) {}
|
||||
|
||||
private:
|
||||
Expr* cond_;
|
||||
Stmt* then_;
|
||||
Stmt* else_;
|
||||
};
|
||||
|
||||
class ForStmt: public Stmt {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
public:
|
||||
static ForStmt* New(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr);
|
||||
virtual ~ForStmt() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
|
||||
protected:
|
||||
ForStmt(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr)
|
||||
: body_(body), init_(init), cond_(cond), step_(step) {}
|
||||
|
||||
private:
|
||||
Stmt* body_;
|
||||
Stmt* init_;
|
||||
Expr* cond_;
|
||||
Expr* step_;
|
||||
};
|
||||
|
||||
class JumpStmt : public Stmt {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static JumpStmt* New(LabelStmt* label);
|
||||
virtual ~JumpStmt() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
void SetLabel(LabelStmt* label) { label_ = label; }
|
||||
|
||||
protected:
|
||||
JumpStmt(LabelStmt* label): label_(label) {}
|
||||
|
||||
private:
|
||||
LabelStmt* label_;
|
||||
};
|
||||
|
||||
|
||||
class ReturnStmt: public Stmt {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static ReturnStmt* New(Expr* expr);
|
||||
virtual ~ReturnStmt() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
|
||||
protected:
|
||||
ReturnStmt(::Expr* expr): expr_(expr) {}
|
||||
|
||||
private:
|
||||
::Expr* expr_;
|
||||
};
|
||||
|
||||
|
||||
using StmtList = std::list<Stmt*>;
|
||||
|
||||
class CompoundStmt : public Stmt {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static CompoundStmt* New(StmtList& stmts, ::Scope* scope=nullptr);
|
||||
virtual ~CompoundStmt() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
StmtList& Stmts() { return stmts_; }
|
||||
::Scope* Scope() { return scope_; }
|
||||
|
||||
protected:
|
||||
CompoundStmt(const StmtList& stmts, ::Scope* scope=nullptr)
|
||||
: stmts_(stmts), scope_(scope) {}
|
||||
|
||||
private:
|
||||
StmtList stmts_;
|
||||
::Scope* scope_;
|
||||
};
|
||||
|
||||
|
||||
struct Initializer {
|
||||
Initializer(Type* type,
|
||||
int offset,
|
||||
Expr* expr,
|
||||
unsigned char bitFieldBegin=0,
|
||||
unsigned char bitFieldWidth=0)
|
||||
: type_(type),
|
||||
offset_(offset),
|
||||
bitFieldBegin_(bitFieldBegin),
|
||||
bitFieldWidth_(bitFieldWidth),
|
||||
expr_(expr) {}
|
||||
|
||||
bool operator<(const Initializer& rhs) const;
|
||||
|
||||
// It could be the object it self or, it will be the member
|
||||
// that was initialized
|
||||
Type* type_;
|
||||
int offset_;
|
||||
unsigned char bitFieldBegin_;
|
||||
unsigned char bitFieldWidth_;
|
||||
|
||||
Expr* expr_;
|
||||
};
|
||||
|
||||
|
||||
using InitList = std::set<Initializer>;
|
||||
|
||||
class Declaration: public Stmt {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static Declaration* New(Object* obj);
|
||||
virtual ~Declaration() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
InitList& Inits() { return inits_; }
|
||||
Object* Obj() { return obj_; }
|
||||
void AddInit(Initializer init);
|
||||
|
||||
protected:
|
||||
Declaration(Object* obj): obj_(obj) {}
|
||||
|
||||
Object* obj_;
|
||||
InitList inits_;
|
||||
};
|
||||
|
||||
|
||||
/*
|
||||
* Expr
|
||||
* BinaryOp
|
||||
* UnaryOp
|
||||
* ConditionalOp
|
||||
* FuncCall
|
||||
* Constant
|
||||
* Identifier
|
||||
* Object
|
||||
* TempVar
|
||||
*/
|
||||
|
||||
class Expr : public Stmt {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
friend class LValAssigner;
|
||||
|
||||
public:
|
||||
virtual ~Expr() {}
|
||||
::Type* Type() { return type_.GetPtr(); }
|
||||
virtual bool IsLVal() = 0;
|
||||
virtual void TypeChecking() = 0;
|
||||
void EnsureCompatible(const QualType lhs, const QualType rhs) const;
|
||||
void EnsureCompatibleOrVoidPointer(const QualType lhs,
|
||||
const QualType rhs) const;
|
||||
const Token* Tok() const { return tok_; }
|
||||
void SetTok(const Token* tok) { tok_ = tok; }
|
||||
|
||||
static Expr* MayCast(Expr* expr);
|
||||
static Expr* MayCast(Expr* expr, QualType desType);
|
||||
static ::Type* TryExtractScalarType(Expr* loc, Expr *operand);
|
||||
static ::Type* ScalarOrLikeTile(Expr* operand, ::Type* ty);
|
||||
|
||||
virtual bool IsNullPointerConstant() const { return false; }
|
||||
bool IsConstQualified() const { return type_.IsConstQualified(); }
|
||||
bool IsRestrictQualified() const { return type_.IsRestrictQualified(); }
|
||||
bool IsVolatileQualified() const { return type_.IsVolatileQualified(); }
|
||||
|
||||
protected:
|
||||
// You can construct a expression without specifying a type,
|
||||
// then the type should be evaluated in TypeChecking()
|
||||
Expr(const Token* tok, QualType type): tok_(tok), type_(type) {}
|
||||
|
||||
const Token* tok_;
|
||||
QualType type_;
|
||||
};
|
||||
|
||||
|
||||
/*
|
||||
* '+', '-', '*', '/', '%', '<', '>', '<<', '>>', '|', '&', '^'
|
||||
* '=',(复合赋值运算符被拆分为两个运算)
|
||||
* '==', '!=', '<=', '>=',
|
||||
* '&&', '||'
|
||||
* '['(下标运算符), '.'(成员运算符)
|
||||
* ','(逗号运算符),
|
||||
*/
|
||||
class BinaryOp : public Expr {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
friend class LValAssigner;
|
||||
friend class Declaration;
|
||||
|
||||
public:
|
||||
static BinaryOp* New(const Token* tok, Expr* lhs, Expr* rhs);
|
||||
static BinaryOp* New(const Token* tok, int op, Expr* lhs, Expr* rhs);
|
||||
virtual ~BinaryOp() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
|
||||
// Member ref operator is a lvalue
|
||||
virtual bool IsLVal() {
|
||||
switch (op_) {
|
||||
case '.': return !Type()->ToArray() && lhs_->IsLVal();
|
||||
case ']': return !Type()->ToArray();
|
||||
case Token::MASKED_DEREF: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
ArithmType* Convert();
|
||||
static void Broadcast(Expr* loc, Expr*& lhs, Expr*& rhs, QualType &type);
|
||||
|
||||
virtual void TypeChecking();
|
||||
void SubScriptingOpTypeChecking();
|
||||
void MemberRefOpTypeChecking();
|
||||
void MultiOpTypeChecking();
|
||||
void AdditiveOpTypeChecking();
|
||||
void ShiftOpTypeChecking();
|
||||
void RangeOpTypeChecking();
|
||||
void MatmulOpTypeChecking();
|
||||
void MaskedDerefOpTypeChecking();
|
||||
void RelationalOpTypeChecking();
|
||||
void EqualityOpTypeChecking();
|
||||
void BitwiseOpTypeChecking();
|
||||
void LogicalOpTypeChecking();
|
||||
void AssignOpTypeChecking();
|
||||
void CommaOpTypeChecking();
|
||||
|
||||
protected:
|
||||
BinaryOp(const Token* tok, int op, Expr* lhs, Expr* rhs)
|
||||
: Expr(tok, nullptr), op_(op) {
|
||||
lhs_ = lhs, rhs_ = rhs;
|
||||
if (op != '.') {
|
||||
lhs_ = MayCast(lhs);
|
||||
rhs_ = MayCast(rhs);
|
||||
}
|
||||
}
|
||||
|
||||
int op_;
|
||||
Expr* lhs_;
|
||||
Expr* rhs_;
|
||||
};
|
||||
|
||||
|
||||
/*
|
||||
* Unary Operator:
|
||||
* '++' (prefix/postfix)
|
||||
* '--' (prefix/postfix)
|
||||
* '&' (ADDR)
|
||||
* '*' (DEREF)
|
||||
* '+' (PLUS)
|
||||
* '-' (MINUS)
|
||||
* '~'
|
||||
* '!'
|
||||
* CAST // like (int)3
|
||||
*/
|
||||
class UnaryOp : public Expr {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
friend class LValAssigner;
|
||||
|
||||
public:
|
||||
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
|
||||
virtual ~UnaryOp() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
virtual bool IsLVal();
|
||||
::Type *Convert();
|
||||
static int encodeRed(int ax, int tag);
|
||||
static void decodeRed(int info, int& ax, int& tag);
|
||||
void TypeChecking();
|
||||
void IncDecOpTypeChecking();
|
||||
void AddrOpTypeChecking();
|
||||
void DerefOpTypeChecking();
|
||||
void ReduceOpTypeChecking();
|
||||
void UnaryArithmOpTypeChecking();
|
||||
void BitcastOpTypeChecking();
|
||||
void CastOpTypeChecking();
|
||||
void IntrinsicOpTypeChecking();
|
||||
|
||||
protected:
|
||||
UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
|
||||
: Expr(operand->Tok(), type), op_(op), info_(info) {
|
||||
operand_ = operand;
|
||||
if (op_ != Token::CAST && op_ != Token::ADDR) {
|
||||
operand_ = MayCast(operand);
|
||||
}
|
||||
}
|
||||
|
||||
int op_;
|
||||
int info_;
|
||||
Expr* operand_;
|
||||
};
|
||||
|
||||
class TransOp: public Expr {
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
using PermInt = std::vector<int>;
|
||||
|
||||
public:
|
||||
static TransOp* New(const PermInt& perm, Expr* operand);
|
||||
const PermInt& getPerm() const { return perm_; }
|
||||
void Accept(Visitor* v);
|
||||
bool IsLVal() { return false; }
|
||||
void TypeChecking();
|
||||
|
||||
protected:
|
||||
TransOp(const PermInt& perm, Expr* operand)
|
||||
: Expr(operand->Tok(), nullptr), operand_(operand), perm_(perm) {}
|
||||
|
||||
private:
|
||||
Expr* operand_;
|
||||
PermInt perm_;
|
||||
};
|
||||
|
||||
|
||||
// cond ? true : false
|
||||
class ConditionalOp : public Expr {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static ConditionalOp* New(const Token* tok,
|
||||
Expr* cond, Expr* exprTrue, Expr* exprFalse);
|
||||
virtual ~ConditionalOp() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
virtual bool IsLVal() { return false; }
|
||||
ArithmType* Convert();
|
||||
virtual void TypeChecking();
|
||||
|
||||
protected:
|
||||
ConditionalOp(Expr* cond, Expr* exprTrue, Expr* exprFalse)
|
||||
: Expr(cond->Tok(), nullptr), cond_(MayCast(cond)),
|
||||
exprTrue_(MayCast(exprTrue)), exprFalse_(MayCast(exprFalse)) {}
|
||||
|
||||
private:
|
||||
Expr* cond_;
|
||||
Expr* exprTrue_;
|
||||
Expr* exprFalse_;
|
||||
};
|
||||
|
||||
|
||||
class FuncCall : public Expr {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
using ArgList = std::vector<Expr*>;
|
||||
|
||||
public:
|
||||
static FuncCall* New(Expr* designator, const ArgList& args);
|
||||
~FuncCall() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
|
||||
// A function call is ofcourse not lvalue
|
||||
virtual bool IsLVal() { return false; }
|
||||
ArgList* Args() { return &args_; }
|
||||
Expr* Designator() { return designator_; }
|
||||
const std::string& Name() const { return tok_->str_; }
|
||||
::FuncType* FuncType() { return designator_->Type()->ToFunc(); }
|
||||
virtual void TypeChecking();
|
||||
|
||||
protected:
|
||||
FuncCall(Expr* designator, const ArgList& args)
|
||||
: Expr(designator->Tok(), nullptr),
|
||||
designator_(designator), args_(args) {}
|
||||
|
||||
Expr* designator_;
|
||||
ArgList args_;
|
||||
};
|
||||
|
||||
|
||||
class Constant: public Expr {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static Constant* New(const Token* tok, int tag, long val);
|
||||
static Constant* New(const Token* tok, int tag, double val);
|
||||
static Constant* New(const Token* tok, int tag, const std::string* val);
|
||||
~Constant() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
virtual bool IsLVal() { return false; }
|
||||
virtual void TypeChecking() {}
|
||||
|
||||
long IVal() const { return ival_; }
|
||||
double FVal() const { return fval_; }
|
||||
const std::string* SVal() const { return sval_; }
|
||||
std::string SValRepr() const;
|
||||
std::string Repr() const { return std::string(".LC") + std::to_string(id_); }
|
||||
|
||||
protected:
|
||||
Constant(const Token* tok, QualType type, long val)
|
||||
: Expr(tok, type), ival_(val) {}
|
||||
Constant(const Token* tok, QualType type, double val)
|
||||
: Expr(tok, type), fval_(val) {}
|
||||
Constant(const Token* tok, QualType type, const std::string* val)
|
||||
: Expr(tok, type), sval_(val) {}
|
||||
|
||||
union {
|
||||
long ival_;
|
||||
double fval_;
|
||||
struct {
|
||||
long id_;
|
||||
const std::string* sval_;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
class TempVar : public Expr {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static TempVar* New(QualType type);
|
||||
virtual ~TempVar() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
virtual bool IsLVal() { return true; }
|
||||
virtual void TypeChecking() {}
|
||||
|
||||
protected:
|
||||
TempVar(QualType type): Expr(nullptr, type), tag_(GenTag()) {}
|
||||
|
||||
private:
|
||||
static int GenTag() {
|
||||
static int tag = 0;
|
||||
return ++tag;
|
||||
}
|
||||
|
||||
int tag_;
|
||||
};
|
||||
|
||||
|
||||
enum Linkage {
|
||||
L_NONE,
|
||||
L_EXTERNAL,
|
||||
L_INTERNAL,
|
||||
};
|
||||
|
||||
|
||||
class Identifier: public Expr {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
friend class LValAssigner;
|
||||
|
||||
public:
|
||||
static Identifier* New(const Token* tok, QualType type, Linkage linkage, const AttrList& attrList={});
|
||||
virtual ~Identifier() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
virtual bool IsLVal() { return false; }
|
||||
virtual Object* ToObject() { return nullptr; }
|
||||
virtual Enumerator* ToEnumerator() { return nullptr; }
|
||||
|
||||
// An identifer can be:
|
||||
// object, sturct/union/enum tag, typedef name, function, label.
|
||||
Identifier* ToTypeName() {
|
||||
// A typename has no linkage
|
||||
// And a function has external or internal linkage
|
||||
if (ToObject() || ToEnumerator() || linkage_ != L_NONE)
|
||||
return nullptr;
|
||||
return this;
|
||||
}
|
||||
virtual const std::string Name() const { return tok_->str_; }
|
||||
enum Linkage Linkage() const { return linkage_; }
|
||||
void SetLinkage(enum Linkage linkage) { linkage_ = linkage; }
|
||||
virtual void TypeChecking() {}
|
||||
|
||||
protected:
|
||||
Identifier(const Token* tok, QualType type, enum Linkage linkage, const AttrList& attrList={})
|
||||
: Expr(tok, type), linkage_(linkage), attrList_(attrList) {}
|
||||
|
||||
// An identifier has property linkage
|
||||
enum Linkage linkage_;
|
||||
AttrList attrList_;
|
||||
};
|
||||
|
||||
|
||||
class Enumerator: public Identifier {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static Enumerator* New(const Token* tok, int val);
|
||||
virtual ~Enumerator() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
virtual Enumerator* ToEnumerator() { return this; }
|
||||
int Val() const { return cons_->IVal(); }
|
||||
|
||||
protected:
|
||||
Enumerator(const Token* tok, int val)
|
||||
: Identifier(tok, ArithmType::New(T_INT), L_NONE),
|
||||
cons_(Constant::New(tok, T_INT, (long)val)) {}
|
||||
|
||||
Constant* cons_;
|
||||
};
|
||||
|
||||
|
||||
class Object : public Identifier {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
friend class LValAssigner;
|
||||
|
||||
public:
|
||||
static Object* New(const Token* tok,
|
||||
QualType type,
|
||||
int storage=0,
|
||||
enum Linkage linkage=L_NONE,
|
||||
unsigned char bitFieldBegin=0,
|
||||
unsigned char bitFieldWidth=0,
|
||||
const AttrList& attrList={});
|
||||
static Object* NewAnony(const Token* tok,
|
||||
QualType type,
|
||||
int storage=0,
|
||||
enum Linkage linkage=L_NONE,
|
||||
unsigned char bitFieldBegin=0,
|
||||
unsigned char bitFieldWidth=0,
|
||||
const AttrList& attrList={});
|
||||
~Object() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
virtual Object* ToObject() { return this; }
|
||||
virtual bool IsLVal() {
|
||||
// TODO(wgtdkp): not all object is lval?
|
||||
return true;
|
||||
}
|
||||
bool IsStatic() const {
|
||||
return (Storage() & S_STATIC) || (Linkage() != L_NONE);
|
||||
}
|
||||
int Storage() const { return storage_; }
|
||||
void SetStorage(int storage) { storage_ = storage; }
|
||||
int Align() const { return align_; }
|
||||
void SetAlign(int align) {
|
||||
assert(align > 0);
|
||||
// Allowing reduce alignment to implement __attribute__((packed))
|
||||
//if (align < align_)
|
||||
// Error(this, "alignment specifier cannot reduce alignment");
|
||||
align_ = align;
|
||||
}
|
||||
int Offset() const { return offset_; }
|
||||
void SetOffset(int offset) { offset_ = offset; }
|
||||
Declaration* Decl() { return decl_; }
|
||||
void SetDecl(Declaration* decl) { decl_ = decl; }
|
||||
const AttrList& GetAttrList() const { return attrList_; }
|
||||
unsigned char BitFieldBegin() const { return bitFieldBegin_; }
|
||||
unsigned char BitFieldEnd() const { return bitFieldBegin_ + bitFieldWidth_; }
|
||||
unsigned char BitFieldWidth() const { return bitFieldWidth_; }
|
||||
static unsigned long BitFieldMask(Object* bitField) {
|
||||
return BitFieldMask(bitField->bitFieldBegin_, bitField->bitFieldWidth_);
|
||||
}
|
||||
static unsigned long BitFieldMask(unsigned char begin, unsigned char width) {
|
||||
auto end = begin + width;
|
||||
return ((0xFFFFFFFFFFFFFFFFUL << (64 - end)) >> (64 - width)) << begin;
|
||||
}
|
||||
|
||||
bool HasInit() const { return decl_ && decl_->Inits().size(); }
|
||||
bool Anonymous() const { return anonymous_; }
|
||||
virtual const std::string Name() const { return Identifier::Name(); }
|
||||
std::string Repr() const {
|
||||
assert(IsStatic() || anonymous_);
|
||||
if (anonymous_)
|
||||
return "anonymous." + std::to_string(id_);
|
||||
if (linkage_ == L_NONE)
|
||||
return Name() + "." + std::to_string(id_);
|
||||
return Name();
|
||||
}
|
||||
|
||||
protected:
|
||||
Object(const Token* tok,
|
||||
QualType type,
|
||||
int storage=0,
|
||||
enum Linkage linkage=L_NONE,
|
||||
unsigned char bitFieldBegin=0,
|
||||
unsigned char bitFieldWidth=0,
|
||||
const AttrList& attrList={})
|
||||
: Identifier(tok, type, linkage),
|
||||
storage_(storage),
|
||||
offset_(0),
|
||||
align_(type->Align()),
|
||||
decl_(nullptr),
|
||||
bitFieldBegin_(bitFieldBegin),
|
||||
bitFieldWidth_(bitFieldWidth),
|
||||
anonymous_(false),
|
||||
attrList_(attrList){}
|
||||
|
||||
private:
|
||||
int storage_;
|
||||
int offset_;
|
||||
int align_;
|
||||
|
||||
Declaration* decl_;
|
||||
|
||||
unsigned char bitFieldBegin_;
|
||||
// 0 means it's not a bitfield
|
||||
unsigned char bitFieldWidth_;
|
||||
|
||||
bool anonymous_;
|
||||
long id_ {0};
|
||||
|
||||
AttrList attrList_;
|
||||
};
|
||||
|
||||
|
||||
/*
|
||||
* Declaration
|
||||
*/
|
||||
|
||||
class FuncDef : public ExtDecl {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
using ParamList = std::vector<Object*>;
|
||||
|
||||
public:
|
||||
static FuncDef* New(Identifier* ident, LabelStmt* retLabel);
|
||||
virtual ~FuncDef() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
::FuncType* FuncType() { return ident_->Type()->ToFunc(); }
|
||||
CompoundStmt* Body() { return body_; }
|
||||
void SetBody(CompoundStmt* body) { body_ = body; }
|
||||
std::string Name() const { return ident_->Name(); }
|
||||
enum Linkage Linkage() { return ident_->Linkage(); }
|
||||
|
||||
protected:
|
||||
FuncDef(Identifier* ident, LabelStmt* retLabel)
|
||||
: ident_(ident), retLabel_(retLabel) {}
|
||||
|
||||
private:
|
||||
Identifier* ident_;
|
||||
LabelStmt* retLabel_;
|
||||
CompoundStmt* body_;
|
||||
};
|
||||
|
||||
|
||||
using ExtDeclList = std::list<ExtDecl*>;
|
||||
|
||||
class TranslationUnit : public ASTNode {
|
||||
template<typename T> friend class Evaluator;
|
||||
friend class AddrEvaluator;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
static TranslationUnit* New() { return new TranslationUnit();}
|
||||
virtual ~TranslationUnit() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
void Add(ExtDecl* extDecl) { extDecls_.push_back(extDecl); }
|
||||
ExtDeclList& ExtDecls() { return extDecls_; }
|
||||
const ExtDeclList& ExtDecls() const { return extDecls_; }
|
||||
|
||||
private:
|
||||
TranslationUnit() {}
|
||||
|
||||
ExtDeclList extDecls_;
|
||||
};
|
||||
|
||||
#endif
|
@@ -1,167 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_CODE_GEN_H_
|
||||
#define _WGTCC_CODE_GEN_H_
|
||||
|
||||
#include "ast.h"
|
||||
#include "visitor.h"
|
||||
#include <stack>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class value;
|
||||
class module;
|
||||
class type;
|
||||
class context;
|
||||
class builder;
|
||||
class attribute;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
using namespace triton;
|
||||
|
||||
class Parser;
|
||||
struct Addr;
|
||||
template<> class Evaluator<Addr>;
|
||||
struct StaticInitializer;
|
||||
class LValAssigner;
|
||||
|
||||
using TypeList = std::vector<Type*>;
|
||||
using LocationList = std::vector<std::string>;
|
||||
using StaticInitList = std::vector<StaticInitializer>;
|
||||
|
||||
// Error
|
||||
inline void should_not_happen(const std::string& suffix) { throw std::runtime_error("internal compiler error: " + suffix); }
|
||||
inline void error_not_implemented(const std::string& msg) { throw std::runtime_error(msg); }
|
||||
|
||||
class Generator: public Visitor {
|
||||
friend class Evaluator<Addr>;
|
||||
friend class LValAssigner;
|
||||
|
||||
protected:
|
||||
struct scope {
|
||||
std::map<std::string, ir::type*> types;
|
||||
std::map<std::string, ir::value*> values;
|
||||
};
|
||||
|
||||
void set_ret(ir::value* value);
|
||||
ir::value *GenUnaryMinus(ir::value* arg);
|
||||
ir::value *GenUnaryInc(UnaryOp* arg, bool is_postfix, bool is_inc);
|
||||
|
||||
public:
|
||||
Generator(Parser* parser) : parser_(parser) {}
|
||||
|
||||
void Visit(ASTNode* node) { node->Accept(this); }
|
||||
void VisitExpr(Expr* expr) { expr->Accept(this); }
|
||||
void VisitStmt(Stmt* stmt) { stmt->Accept(this); }
|
||||
|
||||
// Expression
|
||||
void VisitBinaryOp(BinaryOp* binaryOp);
|
||||
void VisitUnaryOp(UnaryOp* unaryOp);
|
||||
void VisitTransOp(TransOp* transOp);
|
||||
void VisitConditionalOp(ConditionalOp* condOp);
|
||||
void VisitFuncCall(FuncCall* funcCall);
|
||||
void VisitObject(Object* obj);
|
||||
void VisitEnumerator(Enumerator* enumer);
|
||||
void VisitIdentifier(Identifier* ident);
|
||||
void VisitConstant(Constant* cons);
|
||||
void VisitTempVar(TempVar* tempVar);
|
||||
|
||||
// Statement
|
||||
void VisitDeclaration(Declaration* init);
|
||||
void VisitEmptyStmt(EmptyStmt* emptyStmt);
|
||||
void VisitIfStmt(IfStmt* ifStmt);
|
||||
void VisitForStmt(ForStmt* ifStmt);
|
||||
void VisitJumpStmt(JumpStmt* jumpStmt);
|
||||
void VisitReturnStmt(ReturnStmt* returnStmt);
|
||||
void VisitLabelStmt(LabelStmt* labelStmt);
|
||||
void VisitCompoundStmt(CompoundStmt* compoundStmt);
|
||||
|
||||
void VisitFuncDef(FuncDef* funcDef);
|
||||
void VisitTranslationUnit(TranslationUnit* unit);
|
||||
|
||||
void Gen(ir::module *mod);
|
||||
|
||||
protected:
|
||||
// Triton-IR attributes
|
||||
ir::attribute GenIRAttr(ASTNode::Attr attr);
|
||||
|
||||
// Triton-IR metadata
|
||||
void SetIRMetadata(ASTNode::Attr attr, ir::value *rhs);
|
||||
|
||||
// Triton-IR values
|
||||
ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs);
|
||||
ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty);
|
||||
ir::value* GenNumcastOp(ir::value*src, ir::type* dst_ty);
|
||||
ir::value* GenSemCastOp(ir::value* op, ir::type* type);
|
||||
ir::value* GenBitCastOp(ir::value* src, ir::type* dst_ty);
|
||||
|
||||
// Triton-IR types
|
||||
static ir::type* GenIRType(::Type* type, ir::context &ctx);
|
||||
static ir::type* GenIRArithmType(ArithmType* type, ir::context& ctx);
|
||||
static ir::type* GenIRArrayType(ArrayType* type, ir::context& ctx);
|
||||
static ir::type* GenIRTileType(TileType* type, ir::context& ctx);
|
||||
static ir::type* GenIRFuncType(FuncType* type, ir::context& ctx);
|
||||
static ir::type* GenIRPointerType(PointerType* type, ir::context& ctx);
|
||||
static ir::type* GenIRStructType(StructType* type, ir::context& ctx);
|
||||
void AllocObjects(Scope* scope, const FuncDef::ParamList& params=FuncDef::ParamList());
|
||||
|
||||
// SSA
|
||||
void pushScope();
|
||||
void popScope();
|
||||
|
||||
private:
|
||||
Parser* parser_;
|
||||
ir::value* ret_;
|
||||
ir::builder* bld_;
|
||||
ir::context* ctx_;
|
||||
ir::module* mod_;
|
||||
|
||||
private:
|
||||
// std::stack<scope> scopes_;
|
||||
LValAssigner* assign_;
|
||||
};
|
||||
|
||||
|
||||
class LValAssigner: public Visitor {
|
||||
public:
|
||||
LValAssigner(Generator* gen): gen_(gen) {}
|
||||
|
||||
// Expression
|
||||
void VisitBinaryOp(BinaryOp* binaryOp);
|
||||
void VisitUnaryOp(UnaryOp* unaryOp);
|
||||
void VisitObject(Object* obj);
|
||||
void VisitIdentifier(Identifier* ident);
|
||||
|
||||
void VisitConditionalOp(ConditionalOp*) { should_not_happen("conditional cannot be lvalue"); }
|
||||
void VisitFuncCall(FuncCall*) { should_not_happen("funccall cannot be lvalue"); }
|
||||
void VisitTransOp(TransOp*) { should_not_happen("transop cannot be lvalue"); }
|
||||
void VisitEnumerator(Enumerator*) { should_not_happen("enumerator cannot be lvalue"); }
|
||||
void VisitConstant(Constant*) { should_not_happen("constant cannot be lvalue"); }
|
||||
void VisitTempVar(TempVar*) { should_not_happen("tempvar cannot be lvalue"); }
|
||||
void VisitDeclaration(Declaration*) { should_not_happen("declaration cannot be lvalue"); }
|
||||
void VisitEmptyStmt(EmptyStmt*) { should_not_happen("empty statement cannot be lvalue"); }
|
||||
void VisitIfStmt(IfStmt*) { should_not_happen("if statement cannot be lvalue"); }
|
||||
void VisitForStmt(ForStmt*) { should_not_happen("for statement cannot be lvalue"); }
|
||||
void VisitJumpStmt(JumpStmt*) { should_not_happen("jump statement cannot be lvalue"); }
|
||||
void VisitReturnStmt(ReturnStmt*) { should_not_happen("return statement cannot be lvalue"); }
|
||||
void VisitLabelStmt(LabelStmt*) { should_not_happen("label statement cannot be lvalue"); }
|
||||
void VisitCompoundStmt(CompoundStmt*) { should_not_happen("compound statement cannot be lvalue"); }
|
||||
void VisitFuncDef(FuncDef*) { should_not_happen("function definition cannot be lvalue"); }
|
||||
void VisitTranslationUnit(TranslationUnit*) { should_not_happen("translation unit cannot be lvalue"); }
|
||||
|
||||
ir::value* GenExpr(Expr* expr, ir::value* rhs) {
|
||||
rhs_ = rhs;
|
||||
expr->Accept(this);
|
||||
return ret_;
|
||||
}
|
||||
|
||||
private:
|
||||
ir::value* ret_;
|
||||
ir::value* rhs_;
|
||||
Generator* gen_;
|
||||
};
|
||||
|
||||
#endif
|
@@ -1,164 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_CPP_H_
|
||||
#define _WGTCC_CPP_H_
|
||||
|
||||
#include "scanner.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
|
||||
class Macro;
|
||||
struct CondDirective;
|
||||
|
||||
using MacroMap = std::map<std::string, Macro>;
|
||||
using ParamList = std::list<std::string>;
|
||||
using ParamMap = std::map<std::string, TokenSequence>;
|
||||
using PPCondStack = std::stack<CondDirective>;
|
||||
using PathList = std::list<std::string>;
|
||||
|
||||
|
||||
class Macro {
|
||||
public:
|
||||
Macro(const TokenSequence& repSeq, bool preDef=false)
|
||||
: funcLike_(false), variadic_(false),
|
||||
preDef_(preDef), repSeq_(repSeq) {}
|
||||
|
||||
Macro(bool variadic, ParamList& params,
|
||||
TokenSequence& repSeq, bool preDef=false)
|
||||
: funcLike_(true), variadic_(variadic), preDef_(preDef),
|
||||
params_(params), repSeq_(repSeq) {}
|
||||
|
||||
~Macro() {}
|
||||
bool FuncLike() { return funcLike_; }
|
||||
bool ObjLike() { return !FuncLike(); }
|
||||
bool Variadic() { return variadic_; }
|
||||
bool PreDef() { return preDef_; }
|
||||
ParamList& Params() { return params_; }
|
||||
TokenSequence RepSeq(const std::string* filename, unsigned line);
|
||||
|
||||
private:
|
||||
bool funcLike_;
|
||||
bool variadic_;
|
||||
bool preDef_;
|
||||
ParamList params_;
|
||||
TokenSequence repSeq_;
|
||||
};
|
||||
|
||||
|
||||
struct CondDirective {
|
||||
int tag_;
|
||||
bool enabled_;
|
||||
bool cond_;
|
||||
};
|
||||
|
||||
|
||||
class Preprocessor {
|
||||
public:
|
||||
Preprocessor(const std::string* str, bool isSrc = true)
|
||||
: curLine_(1), lineLine_(0), curCond_(true), fName_(nullptr), fSrc_(nullptr) {
|
||||
if(isSrc)
|
||||
fSrc_ = str;
|
||||
else
|
||||
fName_ = str;
|
||||
// Add predefined
|
||||
Init();
|
||||
}
|
||||
|
||||
|
||||
~Preprocessor() {}
|
||||
void Finalize(TokenSequence os);
|
||||
void Process(TokenSequence& os);
|
||||
void Expand(TokenSequence& os, TokenSequence is, bool inCond=false);
|
||||
void Subst(TokenSequence& os, TokenSequence is,
|
||||
bool leadingWS, const HideSet& hs, ParamMap& params);
|
||||
void Glue(TokenSequence& os, TokenSequence is);
|
||||
void Glue(TokenSequence& os, const Token* tok);
|
||||
const Token* Stringize(TokenSequence is);
|
||||
void Stringize(std::string& str, TokenSequence is);
|
||||
const Token* ParseActualParam(TokenSequence& is, Macro* macro, ParamMap& paramMap);
|
||||
int GetDirective(TokenSequence& is);
|
||||
const Token* EvalDefOp(TokenSequence& is);
|
||||
void ReplaceIdent(TokenSequence& is);
|
||||
void ParseDirective(TokenSequence& os, TokenSequence& is, int directive);
|
||||
void ParseIf(TokenSequence ls);
|
||||
void ParseIfdef(TokenSequence ls);
|
||||
void ParseIfndef(TokenSequence ls);
|
||||
void ParseElif(TokenSequence ls);
|
||||
void ParseElse(TokenSequence ls);
|
||||
void ParseEndif(TokenSequence ls);
|
||||
void ParseInclude(TokenSequence& is, TokenSequence ls);
|
||||
void ParseDef(TokenSequence ls);
|
||||
void ParseUndef(TokenSequence ls);
|
||||
void ParseLine(TokenSequence ls);
|
||||
void ParseError(TokenSequence ls);
|
||||
void ParsePragma(TokenSequence ls);
|
||||
void IncludeSrc(TokenSequence& is, const std::string* text, const std::string* filename);
|
||||
void IncludeFile(TokenSequence& is, const std::string* filename);
|
||||
bool ParseIdentList(ParamList& params, TokenSequence& is);
|
||||
|
||||
|
||||
Macro* FindMacro(const std::string& name) {
|
||||
auto res = macroMap_.find(name);
|
||||
if (res == macroMap_.end())
|
||||
return nullptr;
|
||||
return &res->second;
|
||||
}
|
||||
|
||||
void AddMacro(const std::string& name,
|
||||
std::string* text, bool preDef=false);
|
||||
|
||||
void AddMacro(const std::string& name, const Macro& macro) {
|
||||
auto res = macroMap_.find(name);
|
||||
if (res != macroMap_.end()) {
|
||||
// TODO(wgtdkp): give warning
|
||||
macroMap_.erase(res);
|
||||
}
|
||||
macroMap_.insert(std::make_pair(name, macro));
|
||||
}
|
||||
|
||||
void RemoveMacro(const std::string& name) {
|
||||
auto res = macroMap_.find(name);
|
||||
if (res == macroMap_.end())
|
||||
return;
|
||||
if(res->second.PreDef()) // Cannot undef predefined macro
|
||||
return;
|
||||
macroMap_.erase(res);
|
||||
}
|
||||
|
||||
std::string* SearchFile(const std::string& name,
|
||||
const bool libHeader,
|
||||
bool next,
|
||||
const std::string& curPath);
|
||||
|
||||
void AddSearchPath(std::string path);
|
||||
void HandleTheFileMacro(TokenSequence& os, const Token* macro);
|
||||
void HandleTheLineMacro(TokenSequence& os, const Token* macro);
|
||||
void UpdateFirstTokenLine(TokenSequence ts);
|
||||
|
||||
bool NeedExpand() const {
|
||||
if (ppCondStack_.empty())
|
||||
return true;
|
||||
auto top = ppCondStack_.top();
|
||||
return top.enabled_ && top.cond_;
|
||||
}
|
||||
|
||||
private:
|
||||
void Init();
|
||||
|
||||
PPCondStack ppCondStack_;
|
||||
unsigned curLine_;
|
||||
unsigned lineLine_;
|
||||
bool curCond_;
|
||||
|
||||
MacroMap macroMap_;
|
||||
PathList searchPaths_;
|
||||
const std::string* fName_;
|
||||
const std::string* fSrc_;
|
||||
};
|
||||
|
||||
#endif
|
@@ -1,22 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_ENCODING_H_
|
||||
#define _WGTCC_ENCODING_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
|
||||
enum class Encoding {
|
||||
NONE,
|
||||
CHAR16,
|
||||
CHAR32,
|
||||
UTF8,
|
||||
WCHAR
|
||||
};
|
||||
|
||||
|
||||
void ConvertToUTF16(std::string& str);
|
||||
void ConvertToUTF32(std::string& str);
|
||||
void AppendUCN(std::string& str, int c);
|
||||
|
||||
#endif
|
@@ -1,17 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_ERROR_H_
|
||||
#define _WGTCC_ERROR_H_
|
||||
|
||||
|
||||
struct SourceLocation;
|
||||
class Token;
|
||||
class Expr;
|
||||
|
||||
|
||||
[[noreturn]] void Error(const char* format, ...);
|
||||
[[noreturn]] void Error(const SourceLocation& loc, const char* format, ...);
|
||||
[[noreturn]] void Error(const Token* tok, const char* format, ...);
|
||||
[[noreturn]] void Error(const Expr* expr, const char* format, ...);
|
||||
|
||||
#endif
|
@@ -1,130 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_EVALUATOR_H_
|
||||
#define _WGTCC_EVALUATOR_H_
|
||||
|
||||
#include "ast.h"
|
||||
#include "error.h"
|
||||
#include "visitor.h"
|
||||
|
||||
|
||||
class Expr;
|
||||
|
||||
template<typename T>
|
||||
class Evaluator: public Visitor {
|
||||
public:
|
||||
Evaluator() {}
|
||||
|
||||
virtual ~Evaluator() {}
|
||||
|
||||
virtual void VisitBinaryOp(BinaryOp* binary);
|
||||
virtual void VisitUnaryOp(UnaryOp* unary);
|
||||
virtual void VisitConditionalOp(ConditionalOp* cond);
|
||||
|
||||
virtual void VisitFuncCall(FuncCall* funcCall) {
|
||||
Error(funcCall, "expect constant expression");
|
||||
}
|
||||
virtual void VisitEnumerator(Enumerator* enumer) {
|
||||
val_ = static_cast<T>(enumer->Val());
|
||||
}
|
||||
virtual void VisitIdentifier(Identifier* ident) {
|
||||
Error(ident, "expect constant expression");
|
||||
}
|
||||
virtual void VisitTransOp(TransOp* trans) {
|
||||
Error(trans, "expect constant expression");
|
||||
}
|
||||
virtual void VisitObject(Object* obj) {
|
||||
Error(obj, "expect constant expression");
|
||||
}
|
||||
virtual void VisitConstant(Constant* cons) {
|
||||
if (cons->Type()->IsFloat()) {
|
||||
val_ = static_cast<T>(cons->FVal());
|
||||
} else if (cons->Type()->IsInteger()) {
|
||||
val_ = static_cast<T>(cons->IVal());
|
||||
} else {
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
virtual void VisitTempVar(TempVar* tempVar) { assert(false); }
|
||||
|
||||
// We may should assert here
|
||||
virtual void VisitDeclaration(Declaration* init) {}
|
||||
virtual void VisitIfStmt(IfStmt* ifStmt) {}
|
||||
virtual void VisitForStmt(ForStmt* forStmt) {}
|
||||
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
|
||||
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
|
||||
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
|
||||
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {}
|
||||
virtual void VisitCompoundStmt(CompoundStmt* compStmt) {}
|
||||
virtual void VisitFuncDef(FuncDef* funcDef) {}
|
||||
virtual void VisitTranslationUnit(TranslationUnit* unit) {}
|
||||
|
||||
T Eval(Expr* expr) {
|
||||
expr->Accept(this);
|
||||
return val_;
|
||||
}
|
||||
|
||||
private:
|
||||
T val_;
|
||||
};
|
||||
|
||||
|
||||
struct Addr {
|
||||
std::string label_;
|
||||
int offset_;
|
||||
};
|
||||
|
||||
template<>
|
||||
class Evaluator<Addr>: public Visitor {
|
||||
public:
|
||||
Evaluator<Addr>() {}
|
||||
virtual ~Evaluator<Addr>() {}
|
||||
virtual void VisitBinaryOp(BinaryOp* binary);
|
||||
virtual void VisitUnaryOp(UnaryOp* unary);
|
||||
virtual void VisitConditionalOp(ConditionalOp* cond);
|
||||
|
||||
virtual void VisitFuncCall(FuncCall* funcCall) {
|
||||
Error(funcCall, "expect constant expression");
|
||||
}
|
||||
virtual void VisitTransOp(TransOp* trans) {
|
||||
Error(trans, "expect constant expression");
|
||||
}
|
||||
virtual void VisitEnumerator(Enumerator* enumer) {
|
||||
addr_.offset_ = enumer->Val();
|
||||
}
|
||||
virtual void VisitIdentifier(Identifier* ident) {
|
||||
addr_.label_ = ident->Name();
|
||||
addr_.offset_ = 0;
|
||||
}
|
||||
virtual void VisitObject(Object* obj) {
|
||||
if (!obj->IsStatic()) {
|
||||
Error(obj, "expect static object");
|
||||
}
|
||||
addr_.label_ = obj->Repr();
|
||||
addr_.offset_ = 0;
|
||||
}
|
||||
virtual void VisitConstant(Constant* cons);
|
||||
virtual void VisitTempVar(TempVar* tempVar) { assert(false); }
|
||||
|
||||
// We may should assert here
|
||||
virtual void VisitDeclaration(Declaration* init) {}
|
||||
virtual void VisitIfStmt(IfStmt* ifStmt) {}
|
||||
virtual void VisitForStmt(ForStmt* forStmt) {}
|
||||
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
|
||||
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
|
||||
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
|
||||
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {}
|
||||
virtual void VisitCompoundStmt(CompoundStmt* compStmt) {}
|
||||
virtual void VisitFuncDef(FuncDef* funcDef) {}
|
||||
virtual void VisitTranslationUnit(TranslationUnit* unit) {}
|
||||
|
||||
Addr Eval(Expr* expr) {
|
||||
expr->Accept(this);
|
||||
return addr_;
|
||||
}
|
||||
|
||||
private:
|
||||
Addr addr_;
|
||||
};
|
||||
|
||||
#endif
|
@@ -1,103 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_MEM_POOL_H_
|
||||
#define _WGTCC_MEM_POOL_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
||||
|
||||
class MemPool {
|
||||
public:
|
||||
MemPool(): allocated_(0) {}
|
||||
virtual ~MemPool() {}
|
||||
MemPool(const MemPool& other) = delete;
|
||||
MemPool& operator=(const MemPool& other) = delete;
|
||||
virtual void* Alloc() = 0;
|
||||
virtual void Free(void* addr) = 0;
|
||||
virtual void Clear() = 0;
|
||||
|
||||
protected:
|
||||
size_t allocated_;
|
||||
};
|
||||
|
||||
|
||||
template <class T>
|
||||
class MemPoolImp: public MemPool {
|
||||
public:
|
||||
MemPoolImp() : root_(nullptr) {}
|
||||
virtual ~MemPoolImp() {}
|
||||
MemPoolImp(const MemPool& other) = delete;
|
||||
MemPoolImp& operator=(MemPool& other) = delete;
|
||||
virtual void* Alloc();
|
||||
virtual void Free(void* addr);
|
||||
virtual void Clear();
|
||||
|
||||
private:
|
||||
enum {
|
||||
COUNT = (4 * 1024) / sizeof(T)
|
||||
};
|
||||
|
||||
union Chunk {
|
||||
Chunk* next_;
|
||||
char mem_[sizeof(T)];
|
||||
};
|
||||
|
||||
struct Block {
|
||||
Block() {
|
||||
for (size_t i = 0; i < COUNT - 1; ++i)
|
||||
chunks_[i].next_ = &chunks_[i+1];
|
||||
chunks_[COUNT-1].next_ = nullptr;
|
||||
}
|
||||
Chunk chunks_[COUNT];
|
||||
};
|
||||
|
||||
std::vector<Block*> blocks_;
|
||||
Chunk* root_;
|
||||
};
|
||||
|
||||
|
||||
template <class T>
|
||||
void* MemPoolImp<T>::Alloc() {
|
||||
if (nullptr == root_) { // 空间不够,需要分配空间
|
||||
auto block = new Block();
|
||||
root_ = block->chunks_;
|
||||
// 如果blocks实现为std::list, 那么push_back实际的overhead更大
|
||||
// 这也表明,即使我们不需要随机访问功能(那么std::vector的拷贝是一种overhead),
|
||||
// 仍然倾向于使用std::vector,
|
||||
// 当然std::vector的指数级capacity增长会造成内存浪费。
|
||||
blocks_.push_back(block);
|
||||
}
|
||||
|
||||
auto ret = root_;
|
||||
root_ = root_->next_;
|
||||
|
||||
++allocated_;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
template <class T>
|
||||
void MemPoolImp<T>::Free(void* addr) {
|
||||
if (nullptr == addr)
|
||||
return;
|
||||
|
||||
auto chunk = static_cast<Chunk*>(addr);
|
||||
chunk->next_ = root_;
|
||||
root_ = chunk;
|
||||
|
||||
--allocated_;
|
||||
}
|
||||
|
||||
|
||||
template <class T>
|
||||
void MemPoolImp<T>::Clear() {
|
||||
for (auto block: blocks_)
|
||||
delete block;
|
||||
|
||||
blocks_.resize(0);
|
||||
root_ = nullptr;
|
||||
allocated_ = 0;
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,260 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _PARSER_H_
|
||||
#define _PARSER_H_
|
||||
|
||||
#include "ast.h"
|
||||
#include "encoding.h"
|
||||
#include "error.h"
|
||||
#include "mem_pool.h"
|
||||
#include "scope.h"
|
||||
#include "token.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
|
||||
|
||||
class Preprocessor;
|
||||
|
||||
struct DeclInfo {
|
||||
DeclInfo(const Token* _tok,
|
||||
QualType _type,
|
||||
ASTNode::AttrList _attrs = {})
|
||||
: tok(_tok), type(_type), attrs(_attrs) {}
|
||||
|
||||
const Token* tok;
|
||||
QualType type;
|
||||
ASTNode::AttrList attrs;
|
||||
};
|
||||
|
||||
|
||||
class Parser {
|
||||
using LiteralList = std::vector<Constant*>;
|
||||
using StaticObjectList = std::vector<Object*>;
|
||||
using CaseLabelList = std::vector<std::pair<Constant*, LabelStmt*>>;
|
||||
using LabelJumpList = std::list<std::pair<const Token*, JumpStmt*>>;
|
||||
using LabelMap = std::map<std::string, LabelStmt*>;
|
||||
friend class Generator;
|
||||
|
||||
public:
|
||||
explicit Parser(TokenSequence& ts)
|
||||
: unit_(TranslationUnit::New()),
|
||||
ts_(ts),
|
||||
externalSymbols_(new Scope(nullptr, S_BLOCK)),
|
||||
errTok_(nullptr),
|
||||
curScope_(new Scope(nullptr, S_FILE)),
|
||||
curFunc_(nullptr),
|
||||
breakDest_(nullptr),
|
||||
continueDest_(nullptr),
|
||||
caseLabels_(nullptr),
|
||||
defaultLabel_(nullptr) {
|
||||
ts_.SetParser(this);
|
||||
}
|
||||
|
||||
~Parser() {}
|
||||
|
||||
Constant* ParseConstant(const Token* tok);
|
||||
Constant* ParseFloat(const Token* tok);
|
||||
Constant* ParseInteger(const Token* tok);
|
||||
Constant* ParseCharacter(const Token* tok);
|
||||
Encoding ParseLiteral(std::string& str, const Token* tok);
|
||||
Constant* ConcatLiterals(const Token* tok);
|
||||
Expr* ParseGeneric();
|
||||
|
||||
void Parse();
|
||||
void ParseTranslationUnit();
|
||||
FuncDef* ParseFuncDef(Identifier* ident);
|
||||
|
||||
|
||||
// Expressions
|
||||
Expr* ParseExpr();
|
||||
Expr* ParsePrimaryExpr();
|
||||
QualType TryCompoundLiteral();
|
||||
Object* ParseCompoundLiteral(QualType type);
|
||||
Expr* ParsePostfixExpr();
|
||||
Expr* ParsePostfixExprTail(Expr* primExpr);
|
||||
Expr* ParseSubScripting(Expr* pointer);
|
||||
BinaryOp* ParseMemberRef(const Token* tok, int op, Expr* lhs);
|
||||
UnaryOp* ParsePostfixIncDec(const Token* tok, Expr* operand);
|
||||
FuncCall* ParseFuncCall(Expr* caller);
|
||||
|
||||
Expr* ParseUnaryExpr();
|
||||
Constant* ParseSizeof();
|
||||
Constant* ParseAlignof();
|
||||
UnaryOp* ParsePrefixIncDec(const Token* tok);
|
||||
UnaryOp* ParseUnaryIntrinsicOp(int op);
|
||||
UnaryOp* ParseUnaryOp(const Token* tok, int op);
|
||||
Expr* ParseDerefOp(const Token* tok);
|
||||
|
||||
QualType ParseTypeName();
|
||||
Expr* ParseCastExpr();
|
||||
Expr* ParseRangeExpr();
|
||||
Expr* ParseMatmulExpr();
|
||||
Expr* ParseMultiplicativeExpr();
|
||||
Expr* ParseAdditiveExpr();
|
||||
Expr* ParseShiftExpr();
|
||||
Expr* ParseRelationalExpr();
|
||||
Expr* ParseEqualityExpr();
|
||||
Expr* ParseBitiwiseAndExpr();
|
||||
Expr* ParseBitwiseXorExpr();
|
||||
Expr* ParseBitwiseOrExpr();
|
||||
Expr* ParseLogicalAndExpr();
|
||||
Expr* ParseLogicalOrExpr();
|
||||
Expr* ParseConditionalExpr();
|
||||
Expr* ParseCommaExpr();
|
||||
Expr* ParseAssignExpr();
|
||||
|
||||
// Declarations
|
||||
CompoundStmt* ParseDecl();
|
||||
void ParseStaticAssert();
|
||||
QualType ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec);
|
||||
QualType ParseSpecQual();
|
||||
int ParseAlignas();
|
||||
Type* ParseStructUnionSpec(bool isStruct);
|
||||
StructType* ParseStructUnionDecl(StructType* type);
|
||||
void ParseBitField(StructType* structType, const Token* tok, QualType type);
|
||||
Type* ParseEnumSpec();
|
||||
Type* ParseEnumerator(ArithmType* type);
|
||||
int ParseQual();
|
||||
QualType ParsePointer(QualType typePointedTo);
|
||||
DeclInfo ParseDeclarator(QualType type);
|
||||
QualType ParseArrayFuncDeclarator(const Token* ident, QualType base);
|
||||
int ParseArrayLength();
|
||||
TileType::ShapeInt ParseTileShape();
|
||||
bool ParseParamList(FuncType::ParamList& params);
|
||||
Object* ParseParamDecl();
|
||||
|
||||
QualType ParseAbstractDeclarator(QualType type);
|
||||
Identifier* ParseDirectDeclarator(QualType type,
|
||||
int storageSpec,
|
||||
int funcSpec,
|
||||
int align);
|
||||
// Initializer
|
||||
void ParseInitializer(Declaration* decl,
|
||||
QualType type,
|
||||
int offset,
|
||||
bool designated=false,
|
||||
bool forceBrace=false,
|
||||
unsigned char bitFieldBegin=0,
|
||||
unsigned char bitFieldWidth=0);
|
||||
void ParseArrayInitializer(Declaration* decl,
|
||||
ArrayType* type,
|
||||
int offset,
|
||||
bool designated);
|
||||
StructType::Iterator ParseStructDesignator(StructType* type,
|
||||
const std::string& name);
|
||||
void ParseStructInitializer(Declaration* decl,
|
||||
StructType* type,
|
||||
int offset,
|
||||
bool designated);
|
||||
bool ParseLiteralInitializer(Declaration* init,
|
||||
ArrayType* type,
|
||||
int offset);
|
||||
Declaration* ParseInitDeclarator(Identifier* ident);
|
||||
Declaration* ParseInitDeclaratorSub(Object* obj);
|
||||
|
||||
// Statements
|
||||
Stmt* ParseStmt();
|
||||
CompoundStmt* ParseCompoundStmt(FuncType* funcType=nullptr);
|
||||
IfStmt* ParseIfStmt();
|
||||
CompoundStmt* ParseSwitchStmt();
|
||||
CompoundStmt* ParseWhileStmt();
|
||||
CompoundStmt* ParseDoStmt();
|
||||
ForStmt *ParseForStmt();
|
||||
JumpStmt* ParseGotoStmt();
|
||||
JumpStmt* ParseContinueStmt();
|
||||
JumpStmt* ParseBreakStmt();
|
||||
ReturnStmt* ParseReturnStmt();
|
||||
CompoundStmt* ParseLabelStmt(const Token* label);
|
||||
CompoundStmt* ParseCaseStmt();
|
||||
CompoundStmt* ParseDefaultStmt();
|
||||
Identifier* ProcessDeclarator(const Token* tok,
|
||||
QualType type, const ASTNode::AttrList &attrs,
|
||||
int storageSpec,
|
||||
int funcSpec,
|
||||
int align);
|
||||
// GNU extensions
|
||||
ASTNode::AttrList TryAttributeSpecList();
|
||||
void ParseAttributeSpec(ASTNode::AttrList &attrList);
|
||||
ASTNode::Attr ParseAttribute();
|
||||
bool IsTypeName(const Token* tok) const{
|
||||
if (tok->IsTypeSpecQual())
|
||||
return true;
|
||||
|
||||
if (tok->IsIdentifier()) {
|
||||
auto ident = curScope_->Find(tok);
|
||||
if (ident && ident->ToTypeName())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool IsType(const Token* tok) const{
|
||||
if (tok->IsDecl())
|
||||
return true;
|
||||
|
||||
if (tok->IsIdentifier()) {
|
||||
auto ident = curScope_->Find(tok);
|
||||
return (ident && ident->ToTypeName());
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
void EnsureInteger(Expr* expr) {
|
||||
if (!expr->Type()->IsInteger()) {
|
||||
Error(expr, "expect integer expression");
|
||||
}
|
||||
}
|
||||
|
||||
void EnterBlock(FuncType* funcType=nullptr);
|
||||
void ExitBlock() { curScope_ = curScope_->Parent(); }
|
||||
void EnterProto() { curScope_ = new Scope(curScope_, S_PROTO); }
|
||||
void ExitProto() { curScope_ = curScope_->Parent(); }
|
||||
FuncDef* EnterFunc(Identifier* ident);
|
||||
void ExitFunc();
|
||||
|
||||
LabelStmt* FindLabel(const std::string& label) {
|
||||
auto ret = curLabels_.find(label);
|
||||
if (curLabels_.end() == ret)
|
||||
return nullptr;
|
||||
return ret->second;
|
||||
}
|
||||
void AddLabel(const std::string& label, LabelStmt* labelStmt) {
|
||||
assert(nullptr == FindLabel(label));
|
||||
curLabels_[label] = labelStmt;
|
||||
}
|
||||
TranslationUnit* Unit() { return unit_; }
|
||||
FuncDef* CurFunc() { return curFunc_; }
|
||||
const TokenSequence& ts() const { return ts_; }
|
||||
|
||||
protected:
|
||||
static bool IsBuiltin(FuncType* type);
|
||||
static bool IsBuiltin(const std::string& name);
|
||||
static Identifier* GetBuiltin(const Token* tok);
|
||||
static void DefineBuiltins();
|
||||
|
||||
static FuncType* vaStartType_;
|
||||
static FuncType* vaArgType_;
|
||||
|
||||
// The root of the AST
|
||||
TranslationUnit* unit_;
|
||||
|
||||
TokenSequence& ts_;
|
||||
|
||||
// It is not the real scope,
|
||||
// It contains all external symbols(resolved and not resolved)
|
||||
Scope* externalSymbols_;
|
||||
|
||||
const Token* errTok_;
|
||||
Scope* curScope_;
|
||||
FuncDef* curFunc_;
|
||||
LabelMap curLabels_;
|
||||
LabelJumpList unresolvedJumps_;
|
||||
|
||||
LabelStmt* breakDest_;
|
||||
LabelStmt* continueDest_;
|
||||
CaseLabelList* caseLabels_;
|
||||
LabelStmt* defaultLabel_;
|
||||
};
|
||||
|
||||
#endif
|
@@ -1,86 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_SCANNER_H_
|
||||
#define _WGTCC_SCANNER_H_
|
||||
|
||||
#include "error.h"
|
||||
#include "encoding.h"
|
||||
#include "token.h"
|
||||
|
||||
#include <string>
|
||||
#include <cassert>
|
||||
|
||||
|
||||
class Scanner {
|
||||
public:
|
||||
explicit Scanner(const Token* tok)
|
||||
: Scanner(&tok->str_, tok->loc_) {}
|
||||
Scanner(const std::string* text, const SourceLocation& loc)
|
||||
: Scanner(text, loc.filename_, loc.line_, loc.column_) {}
|
||||
explicit Scanner(const std::string* text,
|
||||
const std::string* filename=nullptr,
|
||||
unsigned line=1, unsigned column=1)
|
||||
: text_(text), tok_(Token::END) {
|
||||
// TODO(wgtdkp): initialization
|
||||
p_ = &(*text_)[0];
|
||||
loc_ = {filename, p_, line, 1};
|
||||
}
|
||||
|
||||
virtual ~Scanner() {}
|
||||
Scanner(const Scanner& other) = delete;
|
||||
Scanner& operator=(const Scanner& other) = delete;
|
||||
|
||||
// Scan plain text and generate tokens in ts.
|
||||
// The param 'ts' need not be empty, if so, the tokens
|
||||
// are inserted at the *header* of 'ts'.
|
||||
// The param 'ws' tells if there is leading white space
|
||||
// before this token, it is only SkipComment() that will
|
||||
// set this param.
|
||||
Token* Scan(bool ws=false);
|
||||
void Tokenize(TokenSequence& ts);
|
||||
static std::string ScanHeadName(const Token* lhs, const Token* rhs);
|
||||
Encoding ScanCharacter(int& val);
|
||||
Encoding ScanLiteral(std::string& val);
|
||||
std::string ScanIdentifier();
|
||||
|
||||
private:
|
||||
Token* SkipIdentifier();
|
||||
Token* SkipNumber();
|
||||
Token* SkipLiteral();
|
||||
Token* SkipCharacter();
|
||||
Token* MakeToken(int tag);
|
||||
Token* MakeNewLine();
|
||||
Encoding ScanEncoding(int c);
|
||||
int ScanEscaped();
|
||||
int ScanHexEscaped();
|
||||
int ScanOctEscaped(int c);
|
||||
int ScanUCN(int len);
|
||||
void SkipWhiteSpace();
|
||||
void SkipComment();
|
||||
bool IsUCN(int c) { return c == '\\' && (Test('u') || Test('U')); }
|
||||
bool IsOctal(int c) { return '0' <= c && c <= '7'; }
|
||||
int XDigit(int c);
|
||||
bool Empty() const { return *p_ == 0; }
|
||||
int Peek();
|
||||
bool Test(int c) { return Peek() == c; };
|
||||
int Next();
|
||||
void PutBack();
|
||||
bool Try(int c) {
|
||||
if (Peek() == c) {
|
||||
Next();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
void Mark() { tok_.loc_ = loc_; };
|
||||
|
||||
const std::string* text_;
|
||||
SourceLocation loc_;
|
||||
Token tok_;
|
||||
const char* p_;
|
||||
};
|
||||
|
||||
|
||||
std::string* ReadFile(const std::string& filename);
|
||||
|
||||
#endif
|
@@ -1,72 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_SCOPE_H_
|
||||
#define _WGTCC_SCOPE_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
||||
class Identifier;
|
||||
class Token;
|
||||
|
||||
|
||||
enum ScopeType {
|
||||
S_FILE,
|
||||
S_PROTO,
|
||||
S_BLOCK,
|
||||
S_FUNC,
|
||||
};
|
||||
|
||||
|
||||
class Scope {
|
||||
friend class StructType;
|
||||
using TagList = std::vector<Identifier*>;
|
||||
using IdentMap = std::map<std::string, Identifier*>;
|
||||
|
||||
public:
|
||||
explicit Scope(Scope* parent, enum ScopeType type)
|
||||
: parent_(parent), type_(type) {}
|
||||
~Scope() {}
|
||||
Scope* Parent() { return parent_; }
|
||||
void SetParent(Scope* parent) { parent_ = parent; }
|
||||
enum ScopeType Type() const { return type_; }
|
||||
|
||||
Identifier* Find(const Token* tok);
|
||||
Identifier* FindInCurScope(const Token* tok);
|
||||
Identifier* FindTag(const Token* tok);
|
||||
Identifier* FindTagInCurScope(const Token* tok);
|
||||
TagList AllTagsInCurScope() const;
|
||||
|
||||
void Insert(Identifier* ident);
|
||||
void Insert(const std::string& name, Identifier* ident);
|
||||
void InsertTag(Identifier* ident);
|
||||
void Print();
|
||||
bool operator==(const Scope& other) const { return type_ == other.type_; }
|
||||
IdentMap::iterator begin() { return identMap_.begin(); }
|
||||
IdentMap::iterator end() { return identMap_.end(); }
|
||||
size_t size() const { return identMap_.size(); }
|
||||
|
||||
private:
|
||||
Identifier* Find(const std::string& name);
|
||||
Identifier* FindInCurScope(const std::string& name);
|
||||
Identifier* FindTag(const std::string& name);
|
||||
Identifier* FindTagInCurScope(const std::string& name);
|
||||
std::string TagName(const std::string& name) {
|
||||
return name + "@:tag";
|
||||
}
|
||||
static bool IsTagName(const std::string& name) {
|
||||
return name.size() > 5 && name[name.size() - 5] == '@';
|
||||
}
|
||||
const Scope& operator=(const Scope& other);
|
||||
Scope(const Scope& scope);
|
||||
|
||||
Scope* parent_;
|
||||
enum ScopeType type_;
|
||||
|
||||
IdentMap identMap_;
|
||||
};
|
||||
|
||||
#endif
|
@@ -1,434 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_TOKEN_H_
|
||||
#define _WGTCC_TOKEN_H_
|
||||
|
||||
#include "error.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
|
||||
class Generator;
|
||||
class Parser;
|
||||
class Scanner;
|
||||
class Token;
|
||||
class TokenSequence;
|
||||
|
||||
using HideSet = std::set<std::string>;
|
||||
using TokenList = std::list<const Token*>;
|
||||
|
||||
|
||||
struct SourceLocation {
|
||||
const std::string* filename_;
|
||||
const char* lineBegin_;
|
||||
unsigned line_;
|
||||
unsigned column_;
|
||||
|
||||
const char* Begin() const {
|
||||
return lineBegin_ + column_ - 1;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class Token {
|
||||
friend class Scanner;
|
||||
public:
|
||||
enum {
|
||||
// Punctuators
|
||||
LPAR = '(',
|
||||
RPAR = ')',
|
||||
LSQB = '[',
|
||||
RSQB = ']',
|
||||
COLON = ':',
|
||||
COMMA = ',',
|
||||
SEMI = ';',
|
||||
ADD = '+',
|
||||
SUB = '-',
|
||||
MUL = '*',
|
||||
DIV = '/',
|
||||
OR = '|',
|
||||
AND = '&',
|
||||
XOR = '^',
|
||||
LESS = '<',
|
||||
GREATER = '>',
|
||||
EQUAL = '=',
|
||||
DOT = '.',
|
||||
MOD = '%',
|
||||
LBRACE = '{',
|
||||
RBRACE = '}',
|
||||
TILDE = '~',
|
||||
NOT = '!',
|
||||
COND = '?',
|
||||
SHARP = '#',
|
||||
MATMUL = '@',
|
||||
NEW_LINE = '\n',
|
||||
|
||||
DSHARP = 128, // '##'
|
||||
PTR,
|
||||
INC,
|
||||
DEC,
|
||||
LEFT,
|
||||
RIGHT,
|
||||
LE,
|
||||
GE,
|
||||
EQ,
|
||||
NE,
|
||||
LOGICAL_AND,
|
||||
LOGICAL_OR,
|
||||
|
||||
MUL_ASSIGN,
|
||||
DIV_ASSIGN,
|
||||
MOD_ASSIGN,
|
||||
ADD_ASSIGN,
|
||||
SUB_ASSIGN,
|
||||
LEFT_ASSIGN,
|
||||
RIGHT_ASSIGN,
|
||||
AND_ASSIGN,
|
||||
XOR_ASSIGN,
|
||||
OR_ASSIGN,
|
||||
|
||||
ELLIPSIS,
|
||||
MASKED_DEREF,
|
||||
// Punctuators end
|
||||
|
||||
// KEYWORD BEGIN
|
||||
// TYPE QUALIFIER BEGIN
|
||||
CONST,
|
||||
RESTRICT,
|
||||
VOLATILE,
|
||||
ATOMIC,
|
||||
// TYPE QUALIFIER END
|
||||
|
||||
// TYPE SPECIFIER BEGIN
|
||||
VOID,
|
||||
CHAR,
|
||||
SHORT,
|
||||
INT,
|
||||
LONG,
|
||||
HALF,
|
||||
FLOAT,
|
||||
DOUBLE,
|
||||
SIGNED,
|
||||
UNSIGNED,
|
||||
BOOL, // _Bool
|
||||
COMPLEX, // _Complex
|
||||
STRUCT,
|
||||
UNION,
|
||||
ENUM,
|
||||
// TYPE SPECIFIER END
|
||||
|
||||
ATTRIBUTE, // GNU extension __attribute__
|
||||
// FUNCTION SPECIFIER BEGIN
|
||||
INLINE,
|
||||
NORETURN, // _Noreturn
|
||||
// FUNCTION SPECIFIER END
|
||||
|
||||
// TILE ARITHMETICS BEGIN
|
||||
NEWAXIS,
|
||||
MAX,
|
||||
MIN,
|
||||
// TILE ARITHMETICS END
|
||||
|
||||
ALIGNAS, // _Alignas
|
||||
// For syntactic convenience
|
||||
STATIC_ASSERT, // _Static_assert
|
||||
// STORAGE CLASS SPECIFIER BEGIN
|
||||
TYPEDEF,
|
||||
EXTERN,
|
||||
STATIC,
|
||||
THREAD, // _Thread_local
|
||||
AUTO,
|
||||
GLOBAL,
|
||||
CMEM, // constant memory
|
||||
|
||||
// STORAGE CLASS SPECIFIER END
|
||||
BREAK,
|
||||
CASE,
|
||||
CONTINUE,
|
||||
DEFAULT,
|
||||
DO,
|
||||
ELSE,
|
||||
FOR,
|
||||
GOTO,
|
||||
IF,
|
||||
RETURN,
|
||||
SIZEOF,
|
||||
SWITCH,
|
||||
WHILE,
|
||||
ALIGNOF, // _Alignof
|
||||
GENERIC, // _Generic
|
||||
IMAGINARY, // _Imaginary
|
||||
// function keywords
|
||||
BITCAST,
|
||||
EXP,
|
||||
LOG,
|
||||
SQRTF,
|
||||
// KEYWORD END
|
||||
|
||||
IDENTIFIER,
|
||||
CONSTANT,
|
||||
I_CONSTANT,
|
||||
C_CONSTANT,
|
||||
F_CONSTANT,
|
||||
LITERAL,
|
||||
|
||||
// For the parser, a identifier is a typedef name or user defined type
|
||||
POSTFIX_INC,
|
||||
POSTFIX_DEC,
|
||||
PREFIX_INC,
|
||||
PREFIX_DEC,
|
||||
ADDR, // '&'
|
||||
DEREF, // '*'
|
||||
PLUS,
|
||||
MINUS,
|
||||
CAST,
|
||||
REDUCE,
|
||||
|
||||
// For preprocessor
|
||||
PP_IF,
|
||||
PP_IFDEF,
|
||||
PP_IFNDEF,
|
||||
PP_ELIF,
|
||||
PP_ELSE,
|
||||
PP_ENDIF,
|
||||
PP_INCLUDE,
|
||||
PP_DEFINE,
|
||||
PP_UNDEF,
|
||||
PP_LINE,
|
||||
PP_ERROR,
|
||||
PP_PRAGMA,
|
||||
PP_NONE,
|
||||
PP_EMPTY,
|
||||
|
||||
|
||||
IGNORE,
|
||||
INVALID,
|
||||
END,
|
||||
NOTOK = -1,
|
||||
};
|
||||
|
||||
static Token* New(int tag);
|
||||
static Token* New(const Token& other);
|
||||
static Token* New(int tag,
|
||||
const SourceLocation& loc,
|
||||
const std::string& str,
|
||||
bool ws=false);
|
||||
Token& operator=(const Token& other) {
|
||||
tag_ = other.tag_;
|
||||
ws_ = other.ws_;
|
||||
loc_ = other.loc_;
|
||||
str_ = other.str_;
|
||||
hs_ = other.hs_ ? new HideSet(*other.hs_): nullptr;
|
||||
return *this;
|
||||
}
|
||||
virtual ~Token() {}
|
||||
|
||||
// Token::NOTOK represents not a kw.
|
||||
static int KeyWordTag(const std::string& key) {
|
||||
auto kwIter = kwTypeMap_.find(key);
|
||||
if (kwTypeMap_.end() == kwIter)
|
||||
return Token::NOTOK; // Not a key word type
|
||||
return kwIter->second;
|
||||
}
|
||||
static bool IsKeyWord(const std::string& name);
|
||||
static bool IsKeyWord(int tag) { return CONST <= tag && tag < IDENTIFIER; }
|
||||
bool IsKeyWord() const { return IsKeyWord(tag_); }
|
||||
bool IsPunctuator() const { return 0 <= tag_ && tag_ <= ELLIPSIS; }
|
||||
bool IsLiteral() const { return tag_ == LITERAL; }
|
||||
bool IsConstant() const { return CONSTANT <= tag_ && tag_ <= F_CONSTANT; }
|
||||
bool IsIdentifier() const { return IDENTIFIER == tag_; }
|
||||
bool IsEOF() const { return tag_ == Token::END; }
|
||||
bool IsTypeSpecQual() const { return CONST <= tag_ && tag_ <= ENUM; }
|
||||
bool IsDecl() const { return CONST <= tag_ && tag_ <= GLOBAL; }
|
||||
static const char* Lexeme(int tag) {
|
||||
auto iter = tagLexemeMap_.find(tag);
|
||||
if (iter == tagLexemeMap_.end())
|
||||
return nullptr;
|
||||
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
int tag_;
|
||||
|
||||
// 'ws_' standards for weither there is preceding white space
|
||||
// This is to simplify the '#' operator(stringize) in macro expansion
|
||||
bool ws_ { false };
|
||||
SourceLocation loc_;
|
||||
|
||||
std::string str_;
|
||||
HideSet* hs_ { nullptr };
|
||||
|
||||
private:
|
||||
explicit Token(int tag): tag_(tag) {}
|
||||
Token(int tag, const SourceLocation& loc,
|
||||
const std::string& str, bool ws=false)
|
||||
: tag_(tag), ws_(ws), loc_(loc), str_(str) {}
|
||||
|
||||
Token(const Token& other) {
|
||||
*this = other;
|
||||
}
|
||||
|
||||
static const std::unordered_map<std::string, int> kwTypeMap_;
|
||||
static const std::unordered_map<int, const char*> tagLexemeMap_;
|
||||
};
|
||||
|
||||
|
||||
class TokenSequence {
|
||||
friend class Preprocessor;
|
||||
|
||||
public:
|
||||
TokenSequence(): tokList_(new TokenList()),
|
||||
begin_(tokList_->begin()), end_(tokList_->end()) {}
|
||||
explicit TokenSequence(Token* tok) {
|
||||
TokenSequence();
|
||||
InsertBack(tok);
|
||||
}
|
||||
explicit TokenSequence(TokenList* tokList)
|
||||
: tokList_(tokList),
|
||||
begin_(tokList->begin()),
|
||||
end_(tokList->end()) {}
|
||||
TokenSequence(TokenList* tokList,
|
||||
TokenList::iterator begin,
|
||||
TokenList::iterator end)
|
||||
: tokList_(tokList), begin_(begin), end_(end) {}
|
||||
~TokenSequence() {}
|
||||
TokenSequence(const TokenSequence& other) { *this = other; }
|
||||
const TokenSequence& operator=(const TokenSequence& other) {
|
||||
tokList_ = other.tokList_;
|
||||
begin_ = other.begin_;
|
||||
end_ = other.end_;
|
||||
return *this;
|
||||
}
|
||||
void Copy(const TokenSequence& other) {
|
||||
tokList_ = new TokenList(other.begin_, other.end_);
|
||||
begin_ = tokList_->begin();
|
||||
end_ = tokList_->end();
|
||||
for (auto iter = begin_; iter != end_; ++iter)
|
||||
*iter = Token::New(**iter);
|
||||
}
|
||||
void UpdateHeadLocation(const SourceLocation& loc) {
|
||||
assert(!Empty());
|
||||
auto tok = const_cast<Token*>(Peek());
|
||||
tok->loc_ = loc;
|
||||
}
|
||||
void FinalizeSubst(bool leadingWS, const HideSet& hs) {
|
||||
auto ts = *this;
|
||||
while (!ts.Empty()) {
|
||||
auto tok = const_cast<Token*>(ts.Next());
|
||||
if (!tok->hs_)
|
||||
tok->hs_ = new HideSet(hs);
|
||||
else
|
||||
tok->hs_->insert(hs.begin(), hs.end());
|
||||
}
|
||||
// Even if the token sequence is empty
|
||||
const_cast<Token*>(Peek())->ws_ = leadingWS;
|
||||
}
|
||||
|
||||
const Token* Expect(int expect);
|
||||
bool Try(int tag) {
|
||||
if (Peek()->tag_ == tag) {
|
||||
Next();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool Test(int tag) { return Peek()->tag_ == tag; }
|
||||
const Token* Next() {
|
||||
auto ret = Peek();
|
||||
if (!ret->IsEOF()) {
|
||||
++begin_;
|
||||
Peek(); // May skip newline token, but why ?
|
||||
} else {
|
||||
++exceed_end;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
void PutBack() {
|
||||
assert(begin_ != tokList_->begin());
|
||||
if (exceed_end > 0) {
|
||||
--exceed_end;
|
||||
} else {
|
||||
--begin_;
|
||||
if ((*begin_)->tag_ == Token::NEW_LINE)
|
||||
PutBack();
|
||||
}
|
||||
}
|
||||
const Token* Peek() const;
|
||||
const Token* Peek2() {
|
||||
if (Empty())
|
||||
return Peek(); // Return the Token::END
|
||||
Next();
|
||||
auto ret = Peek();
|
||||
PutBack();
|
||||
return ret;
|
||||
}
|
||||
const Token* Back() const {
|
||||
auto back = end_;
|
||||
return *--back;
|
||||
}
|
||||
void PopBack() {
|
||||
assert(!Empty());
|
||||
assert(end_ == tokList_->end());
|
||||
auto size_eq1 = tokList_->back() == *begin_;
|
||||
tokList_->pop_back();
|
||||
end_ = tokList_->end();
|
||||
if (size_eq1)
|
||||
begin_ = end_;
|
||||
}
|
||||
TokenList::iterator Mark() { return begin_; }
|
||||
void ResetTo(TokenList::iterator mark) { begin_ = mark; }
|
||||
bool Empty() const { return Peek()->tag_ == Token::END; }
|
||||
void InsertBack(TokenSequence& ts) {
|
||||
auto pos = tokList_->insert(end_, ts.begin_, ts.end_);
|
||||
if (begin_ == end_) {
|
||||
begin_ = pos;
|
||||
}
|
||||
}
|
||||
void InsertBack(const Token* tok) {
|
||||
auto pos = tokList_->insert(end_, tok);
|
||||
if (begin_ == end_) {
|
||||
begin_ = pos;
|
||||
}
|
||||
}
|
||||
|
||||
// If there is preceding newline
|
||||
void InsertFront(TokenSequence& ts) {
|
||||
auto pos = GetInsertFrontPos();
|
||||
begin_ = tokList_->insert(pos, ts.begin_, ts.end_);
|
||||
}
|
||||
void InsertFront(const Token* tok) {
|
||||
auto pos = GetInsertFrontPos();
|
||||
begin_ = tokList_->insert(pos, tok);
|
||||
}
|
||||
bool IsBeginOfLine() const;
|
||||
TokenSequence GetLine();
|
||||
void SetParser(Parser* parser) { parser_ = parser; }
|
||||
void Print(FILE* fp=stdout) const;
|
||||
void Print(std::string *str) const;
|
||||
|
||||
private:
|
||||
// Find a insert position with no preceding newline
|
||||
TokenList::iterator GetInsertFrontPos() {
|
||||
auto pos = begin_;
|
||||
if (pos == tokList_->begin())
|
||||
return pos;
|
||||
--pos;
|
||||
while (pos != tokList_->begin() && (*pos)->tag_ == Token::NEW_LINE)
|
||||
--pos;
|
||||
return ++pos;
|
||||
}
|
||||
|
||||
TokenList* tokList_;
|
||||
mutable TokenList::iterator begin_;
|
||||
TokenList::iterator end_;
|
||||
Parser* parser_ {nullptr};
|
||||
int exceed_end {0};
|
||||
};
|
||||
|
||||
#endif
|
@@ -1,453 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_TYPE_H_
|
||||
#define _WGTCC_TYPE_H_
|
||||
|
||||
#include "mem_pool.h"
|
||||
#include "scope.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <list>
|
||||
|
||||
|
||||
class Scope;
|
||||
class Token;
|
||||
class Expr;
|
||||
|
||||
class Type;
|
||||
class QualType;
|
||||
class VoidType;
|
||||
class Identifier;
|
||||
class Object;
|
||||
class Constant;
|
||||
|
||||
class ArithmType;
|
||||
class DerivedType;
|
||||
class ArrayType;
|
||||
class TileType;
|
||||
class FuncType;
|
||||
class PointerType;
|
||||
class StructType;
|
||||
class EnumType;
|
||||
|
||||
|
||||
enum {
|
||||
// Storage class specifiers
|
||||
S_TYPEDEF = 0x01,
|
||||
S_EXTERN = 0x02,
|
||||
S_STATIC = 0x04,
|
||||
S_THREAD = 0x08,
|
||||
S_CONSTANT = 0x10,
|
||||
S_GLOBAL = 0x20,
|
||||
|
||||
// Type specifier
|
||||
T_SIGNED = 0x40,
|
||||
T_UNSIGNED = 0x80,
|
||||
T_CHAR = 0x100,
|
||||
T_SHORT = 0x200,
|
||||
T_INT = 0x400,
|
||||
T_LONG = 0x800,
|
||||
T_VOID = 0x1000,
|
||||
T_HALF = 0x2000,
|
||||
T_FLOAT = 0x4000,
|
||||
T_DOUBLE = 0x8000,
|
||||
T_BOOL = 0x10000,
|
||||
T_COMPLEX = 0x20000,
|
||||
// T_ATOMIC = 0x40000,
|
||||
T_STRUCT_UNION = 0x80000,
|
||||
T_ENUM = 0x100000,
|
||||
T_TYPEDEF_NAME = 0x200000,
|
||||
|
||||
T_LLONG = 0x4000000,
|
||||
|
||||
// Function specifier
|
||||
F_INLINE = 0x8000000,
|
||||
F_NORETURN = 0x10000000,
|
||||
};
|
||||
|
||||
|
||||
struct Qualifier {
|
||||
enum {
|
||||
CONST = 0x01,
|
||||
RESTRICT = 0x02,
|
||||
VOLATILE = 0x04,
|
||||
CMEM = 0x08,
|
||||
MASK = CONST | RESTRICT | VOLATILE | CMEM
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
class QualType {
|
||||
public:
|
||||
QualType(Type* ptr, int quals=0x00)
|
||||
: ptr_(reinterpret_cast<intptr_t>(ptr)) {
|
||||
assert((quals & ~Qualifier::MASK) == 0);
|
||||
ptr_ |= quals;
|
||||
}
|
||||
|
||||
operator bool() const { return !IsNull(); }
|
||||
bool IsNull() const { return GetPtr() == nullptr; }
|
||||
const Type* GetPtr() const {
|
||||
return reinterpret_cast<const Type*>(ptr_ & ~Qualifier::MASK);
|
||||
}
|
||||
Type* GetPtr() {
|
||||
return reinterpret_cast<Type*>(ptr_ & ~Qualifier::MASK);
|
||||
}
|
||||
Type& operator*() { return *GetPtr(); }
|
||||
const Type& operator*() const { return *GetPtr(); }
|
||||
Type* operator->() { return GetPtr(); }
|
||||
const Type* operator->() const { return GetPtr(); }
|
||||
|
||||
// Indicate whether the specified types are identical(exclude qualifiers).
|
||||
friend bool operator==(QualType lhs, QualType rhs) {
|
||||
return lhs.operator->() == rhs.operator->();
|
||||
}
|
||||
friend bool operator!=(QualType lhs, QualType rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
int Qual() const { return ptr_ & 0x07; }
|
||||
bool IsConstQualified() const { return ptr_ & Qualifier::CONST; }
|
||||
bool IsRestrictQualified() const { return ptr_ & Qualifier::RESTRICT; }
|
||||
bool IsVolatileQualified() const { return ptr_ & Qualifier::VOLATILE; }
|
||||
bool IsConstantQualified() const { return ptr_ & Qualifier::CMEM; }
|
||||
|
||||
private:
|
||||
intptr_t ptr_;
|
||||
};
|
||||
|
||||
|
||||
class Type {
|
||||
public:
|
||||
static const int intWidth_ = 4;
|
||||
static const int machineWidth_ = 8;
|
||||
|
||||
bool operator!=(const Type& other) const = delete;
|
||||
bool operator==(const Type& other) const = delete;
|
||||
|
||||
virtual bool Compatible(const Type& other) const {
|
||||
return complete_ == other.complete_;
|
||||
}
|
||||
|
||||
virtual ~Type() {}
|
||||
|
||||
// For Debugging
|
||||
virtual std::string Str() const = 0;
|
||||
virtual int Width() const = 0;
|
||||
virtual int Align() const { return Width(); }
|
||||
static int MakeAlign(int offset, int align) {
|
||||
if ((offset % align) == 0)
|
||||
return offset;
|
||||
if (offset >= 0)
|
||||
return offset + align - (offset % align);
|
||||
else
|
||||
return offset - align - (offset % align);
|
||||
}
|
||||
|
||||
static QualType MayCast(QualType type, bool inProtoScope=false);
|
||||
bool Complete() const { return complete_; }
|
||||
void SetComplete(bool complete) const { complete_ = complete; }
|
||||
|
||||
bool IsReal() const { return IsInteger() || IsFloat(); };
|
||||
virtual bool IsScalar() const { return false; }
|
||||
virtual bool IsFloat() const { return false; }
|
||||
virtual bool IsInteger() const { return false; }
|
||||
virtual bool IsBool() const { return false; }
|
||||
virtual bool IsVoidPointer() const { return false; }
|
||||
virtual bool IsUnsigned() const { return false; }
|
||||
virtual bool IsTile() const { return ToTile() != nullptr; }
|
||||
|
||||
const Type* ScalarType() const;
|
||||
Type* ScalarType();
|
||||
|
||||
virtual VoidType* ToVoid() { return nullptr; }
|
||||
virtual const VoidType* ToVoid() const { return nullptr; }
|
||||
virtual ArithmType* ToArithm() { return nullptr; }
|
||||
virtual const ArithmType* ToArithm() const { return nullptr; }
|
||||
virtual ArrayType* ToArray() { return nullptr; }
|
||||
virtual const ArrayType* ToArray() const { return nullptr; }
|
||||
virtual TileType* ToTile() { return nullptr; }
|
||||
virtual const TileType* ToTile() const { return nullptr; }
|
||||
virtual FuncType* ToFunc() { return nullptr; }
|
||||
virtual const FuncType* ToFunc() const { return nullptr; }
|
||||
virtual PointerType* ToPointer() { return nullptr; }
|
||||
virtual const PointerType* ToPointer() const { return nullptr; }
|
||||
virtual DerivedType* ToDerived() { return nullptr; }
|
||||
virtual const DerivedType* ToDerived() const { return nullptr; }
|
||||
virtual StructType* ToStruct() { return nullptr; }
|
||||
virtual const StructType* ToStruct() const { return nullptr; }
|
||||
|
||||
protected:
|
||||
Type(MemPool* pool, bool complete)
|
||||
: complete_(complete), pool_(pool) {}
|
||||
|
||||
mutable bool complete_;
|
||||
MemPool* pool_;
|
||||
};
|
||||
|
||||
|
||||
class VoidType : public Type {
|
||||
public:
|
||||
static VoidType* New();
|
||||
virtual ~VoidType() {}
|
||||
virtual VoidType* ToVoid() { return this; }
|
||||
virtual const VoidType* ToVoid() const { return this; }
|
||||
virtual bool Compatible(const Type& other) const { return other.ToVoid(); }
|
||||
virtual int Width() const {
|
||||
// Non-standard GNU extension
|
||||
return 1;
|
||||
}
|
||||
virtual std::string Str() const { return "void:1"; }
|
||||
|
||||
protected:
|
||||
explicit VoidType(MemPool* pool): Type(pool, false) {}
|
||||
};
|
||||
|
||||
|
||||
class ArithmType : public Type {
|
||||
public:
|
||||
static ArithmType* New(int typeSpec);
|
||||
|
||||
virtual ~ArithmType() {}
|
||||
virtual ArithmType* ToArithm() { return this; }
|
||||
virtual const ArithmType* ToArithm() const { return this; }
|
||||
virtual bool Compatible(const Type& other) const {
|
||||
// C11 6.2.7 [1]: Two types have compatible type if their types are the same
|
||||
// But I would to loose this constraints: integer and pointer are compatible
|
||||
// if (IsInteger() && other.ToPointer())
|
||||
// return other.Compatible(*this);
|
||||
return this == &other;
|
||||
}
|
||||
|
||||
virtual int Width() const;
|
||||
virtual std::string Str() const;
|
||||
virtual bool IsScalar() const { return true; }
|
||||
virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); }
|
||||
virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; }
|
||||
virtual bool IsFloat() const {
|
||||
return (tag_ & T_HALF) || (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
|
||||
}
|
||||
virtual bool IsBool() const { return tag_ & T_BOOL; }
|
||||
bool IsComplex() const { return tag_ & T_COMPLEX; }
|
||||
int Tag() const { return tag_; }
|
||||
int Rank() const;
|
||||
static ArithmType* IntegerPromote(ArithmType* type) {
|
||||
assert(type->IsInteger());
|
||||
if (type->Rank() < ArithmType::New(T_INT)->Rank())
|
||||
return ArithmType::New(T_INT);
|
||||
return type;
|
||||
}
|
||||
static ArithmType* MaxType(ArithmType* lhsType,
|
||||
ArithmType* rhsType);
|
||||
|
||||
protected:
|
||||
explicit ArithmType(MemPool* pool, int spec)
|
||||
: Type(pool, true), tag_(Spec2Tag(spec)) {}
|
||||
|
||||
private:
|
||||
static int Spec2Tag(int spec);
|
||||
|
||||
int tag_;
|
||||
};
|
||||
|
||||
|
||||
class DerivedType : public Type {
|
||||
public:
|
||||
QualType Derived() const { return derived_; }
|
||||
void SetDerived(QualType derived) { derived_ = derived; }
|
||||
virtual DerivedType* ToDerived() { return this; }
|
||||
virtual const DerivedType* ToDerived() const { return this; }
|
||||
|
||||
protected:
|
||||
DerivedType(MemPool* pool, QualType derived)
|
||||
: Type(pool, true), derived_(derived) {}
|
||||
|
||||
QualType derived_;
|
||||
};
|
||||
|
||||
|
||||
class PointerType : public DerivedType {
|
||||
public:
|
||||
static PointerType* New(QualType derived);
|
||||
virtual ~PointerType() {}
|
||||
virtual PointerType* ToPointer() { return this; }
|
||||
virtual const PointerType* ToPointer() const { return this; }
|
||||
virtual bool Compatible(const Type& other) const;
|
||||
virtual int Width() const { return 8; }
|
||||
virtual bool IsScalar() const { return true; }
|
||||
virtual bool IsVoidPointer() const { return derived_->ToVoid(); }
|
||||
virtual std::string Str() const {
|
||||
return derived_->Str() + "*:" + std::to_string(Width());
|
||||
}
|
||||
|
||||
protected:
|
||||
PointerType(MemPool* pool, QualType derived): DerivedType(pool, derived) {}
|
||||
};
|
||||
|
||||
|
||||
class ArrayType : public DerivedType {
|
||||
public:
|
||||
static ArrayType* New(int len, QualType eleType);
|
||||
static ArrayType* New(Expr* expr, QualType eleType);
|
||||
virtual ~ArrayType() { /*delete derived_;*/ }
|
||||
|
||||
virtual ArrayType* ToArray() { return this; }
|
||||
virtual const ArrayType* ToArray() const { return this; }
|
||||
virtual bool Compatible(const Type& other) const;
|
||||
virtual int Width() const {
|
||||
return Complete() ? (derived_->Width() * len_): 0;
|
||||
}
|
||||
virtual int Align() const { return derived_->Align(); }
|
||||
virtual std::string Str() const {
|
||||
return derived_->Str() + "[]:" + std::to_string(Width());
|
||||
}
|
||||
|
||||
int GetElementOffset(int idx) const { return derived_->Width() * idx; }
|
||||
int Len() const { return len_; }
|
||||
void SetLen(int len) { len_ = len; }
|
||||
bool Variadic() const { return lenExpr_ != nullptr; }
|
||||
|
||||
protected:
|
||||
ArrayType(MemPool* pool, Expr* lenExpr, QualType derived)
|
||||
: DerivedType(pool, derived),
|
||||
lenExpr_(lenExpr), len_(0) {
|
||||
SetComplete(false);
|
||||
}
|
||||
|
||||
ArrayType(MemPool* pool, int len, QualType derived)
|
||||
: DerivedType(pool, derived),
|
||||
lenExpr_(nullptr), len_(len) {
|
||||
SetComplete(len_ >= 0);
|
||||
}
|
||||
const Expr* lenExpr_;
|
||||
int len_;
|
||||
};
|
||||
|
||||
class TileType : public DerivedType {
|
||||
public:
|
||||
using ShapeExpr = std::vector<Expr*>;
|
||||
using ShapeInt = std::vector<int>;
|
||||
|
||||
public:
|
||||
static TileType* New(const ShapeInt& shape, QualType eleType);
|
||||
virtual ~TileType() { }
|
||||
|
||||
virtual TileType* ToTile() { return this; }
|
||||
virtual const TileType* ToTile() const { return this; }
|
||||
virtual bool Compatible(const Type& other) const;
|
||||
virtual int Width() const { return Complete() ? derived_->Width()*NumEle() : 0; }
|
||||
virtual int Align() const { return derived_->Align(); }
|
||||
virtual std::string Str() const {
|
||||
return derived_->Str() + "[{}]:" + std::to_string(Width());
|
||||
}
|
||||
|
||||
ShapeInt Shape() { return shape_; }
|
||||
|
||||
int NumEle() const {
|
||||
int ret = 1;
|
||||
for(int s: shape_)
|
||||
ret *= s;
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool CheckPow2NumEl() const {
|
||||
int n = NumEle();
|
||||
return n && !(n & (n - 1));
|
||||
}
|
||||
|
||||
protected:
|
||||
TileType(MemPool* pool, const ShapeInt& shape, QualType derived);
|
||||
|
||||
protected:
|
||||
ShapeExpr shapeExpr_;
|
||||
ShapeInt shape_;
|
||||
};
|
||||
|
||||
class FuncType : public DerivedType {
|
||||
public:
|
||||
using ParamList = std::vector<Object*>;
|
||||
|
||||
public:
|
||||
static FuncType* New(QualType derived,
|
||||
int funcSpec,
|
||||
bool variadic,
|
||||
const ParamList& params);
|
||||
virtual ~FuncType() {}
|
||||
virtual FuncType* ToFunc() { return this; }
|
||||
virtual const FuncType* ToFunc() const { return this; }
|
||||
virtual bool Compatible(const Type& other) const;
|
||||
virtual int Width() const { return 1; }
|
||||
virtual std::string Str() const;
|
||||
const ParamList& Params() const { return params_; }
|
||||
void SetParams(const ParamList& params) { params_ = params; }
|
||||
bool Variadic() const { return variadic_; }
|
||||
bool IsInline() const { return inlineNoReturn_ & F_INLINE; }
|
||||
bool IsNoReturn() const { return inlineNoReturn_ & F_NORETURN; }
|
||||
|
||||
protected:
|
||||
FuncType(MemPool* pool, QualType derived, int inlineReturn,
|
||||
bool variadic, const ParamList& params)
|
||||
: DerivedType(pool, derived), inlineNoReturn_(inlineReturn),
|
||||
variadic_(variadic), params_(params) {
|
||||
SetComplete(false);
|
||||
}
|
||||
|
||||
private:
|
||||
int inlineNoReturn_;
|
||||
bool variadic_;
|
||||
ParamList params_;
|
||||
};
|
||||
|
||||
|
||||
class StructType : public Type {
|
||||
public:
|
||||
using MemberList = std::list<Object*>;
|
||||
using Iterator = std::list<Object*>::iterator;
|
||||
|
||||
public:
|
||||
static StructType* New(bool isStruct,
|
||||
bool hasTag,
|
||||
Scope* parent);
|
||||
virtual ~StructType() {}
|
||||
virtual StructType* ToStruct() { return this; }
|
||||
virtual const StructType* ToStruct() const { return this; }
|
||||
virtual bool Compatible(const Type& other) const;
|
||||
virtual int Width() const { return width_; }
|
||||
virtual int Align() const { return align_; }
|
||||
virtual std::string Str() const;
|
||||
|
||||
// struct/union
|
||||
void AddMember(Object* member);
|
||||
void AddBitField(Object* member, int offset);
|
||||
bool IsStruct() const { return isStruct_; }
|
||||
Object* GetMember(const std::string& member);
|
||||
Scope* MemberMap() { return memberMap_; }
|
||||
MemberList& Members() { return members_; }
|
||||
int Offset() const { return offset_; }
|
||||
bool HasTag() const { return hasTag_; }
|
||||
void MergeAnony(Object* anony);
|
||||
void Finalize();
|
||||
|
||||
protected:
|
||||
// Default is incomplete
|
||||
StructType(MemPool* pool, bool isStruct, bool hasTag, Scope* parent);
|
||||
|
||||
StructType(const StructType& other);
|
||||
|
||||
private:
|
||||
void CalcWidth();
|
||||
|
||||
bool isStruct_;
|
||||
bool hasTag_;
|
||||
Scope* memberMap_;
|
||||
|
||||
MemberList members_;
|
||||
int offset_;
|
||||
int width_;
|
||||
int align_;
|
||||
int bitFieldAlign_;
|
||||
};
|
||||
|
||||
#endif
|
@@ -1,56 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _WGTCC_VISITOR_H_
|
||||
#define _WGTCC_VISITOR_H_
|
||||
|
||||
|
||||
class BinaryOp;
|
||||
class UnaryOp;
|
||||
class TransOp;
|
||||
class ConditionalOp;
|
||||
class FuncCall;
|
||||
class Identifier;
|
||||
class Object;
|
||||
class Enumerator;
|
||||
class Constant;
|
||||
class TempVar;
|
||||
|
||||
class Declaration;
|
||||
class IfStmt;
|
||||
class ForStmt;
|
||||
class JumpStmt;
|
||||
class ReturnStmt;
|
||||
class LabelStmt;
|
||||
class EmptyStmt;
|
||||
class CompoundStmt;
|
||||
class FuncDef;
|
||||
class TranslationUnit;
|
||||
|
||||
|
||||
class Visitor {
|
||||
public:
|
||||
virtual ~Visitor() {}
|
||||
virtual void VisitBinaryOp(BinaryOp* binary) = 0;
|
||||
virtual void VisitUnaryOp(UnaryOp* unary) = 0;
|
||||
virtual void VisitTransOp(TransOp* trans) = 0;
|
||||
virtual void VisitConditionalOp(ConditionalOp* cond) = 0;
|
||||
virtual void VisitFuncCall(FuncCall* funcCall) = 0;
|
||||
virtual void VisitEnumerator(Enumerator* enumer) = 0;
|
||||
virtual void VisitIdentifier(Identifier* ident) = 0;
|
||||
virtual void VisitObject(Object* obj) = 0;
|
||||
virtual void VisitConstant(Constant* cons) = 0;
|
||||
virtual void VisitTempVar(TempVar* tempVar) = 0;
|
||||
|
||||
virtual void VisitDeclaration(Declaration* init) = 0;
|
||||
virtual void VisitIfStmt(IfStmt* ifStmt) = 0;
|
||||
virtual void VisitForStmt(ForStmt* ifStmt) = 0;
|
||||
virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0;
|
||||
virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0;
|
||||
virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0;
|
||||
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) = 0;
|
||||
virtual void VisitCompoundStmt(CompoundStmt* compStmt) = 0;
|
||||
virtual void VisitFuncDef(FuncDef* funcDef) = 0;
|
||||
virtual void VisitTranslationUnit(TranslationUnit* unit) = 0;
|
||||
};
|
||||
|
||||
#endif
|
@@ -1,27 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_RUNTIME_ARG_H_
|
||||
#define _TRITON_RUNTIME_ARG_H_
|
||||
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <sstream>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
class type;
|
||||
}
|
||||
|
||||
namespace driver{
|
||||
class buffer;
|
||||
}
|
||||
|
||||
namespace runtime {
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,34 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_RUNTIME_ERROR_H_
|
||||
#define _TRITON_RUNTIME_ERROR_H_
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
|
||||
namespace triton {
|
||||
namespace runtime{
|
||||
namespace exception {
|
||||
|
||||
class base: public std::exception {};
|
||||
#define TRITON_CREATE_RUNTIME_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "Triton: Error - Runtime: " msg; } };
|
||||
|
||||
TRITON_CREATE_RUNTIME_EXCEPTION(out_of_shared_memory, "out of shared memory")
|
||||
TRITON_CREATE_RUNTIME_EXCEPTION(out_of_registers, "out of registers")
|
||||
|
||||
class no_valid_configuration: public exception::base {
|
||||
public:
|
||||
no_valid_configuration(const std::string& err): err_(err) { }
|
||||
const char * what() const throw(){ return err_.c_str(); }
|
||||
private:
|
||||
std::string err_;
|
||||
};
|
||||
|
||||
|
||||
#undef TRITON_CREATE_RUNTIME_EXCEPTION
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -1,159 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef _TRITON_RUNTIME_FUNCTION_H_
|
||||
#define _TRITON_RUNTIME_FUNCTION_H_
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
// codegen
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/runtime/arg.h"
|
||||
#include "triton/runtime/error.h"
|
||||
|
||||
// driver forward declaration
|
||||
namespace triton {
|
||||
namespace driver{
|
||||
class module;
|
||||
class stream;
|
||||
class kernel;
|
||||
class context;
|
||||
class device;
|
||||
}
|
||||
}
|
||||
// ir forward declaration
|
||||
namespace triton{
|
||||
namespace ir {
|
||||
class module;
|
||||
class function;
|
||||
class context;
|
||||
}
|
||||
}
|
||||
|
||||
namespace triton{
|
||||
namespace runtime{
|
||||
|
||||
|
||||
/* ------------------------- */
|
||||
/* Compilation options */
|
||||
/* ------------------------- */
|
||||
|
||||
struct options_t {
|
||||
template<class T>
|
||||
T D(const std::string& name) const {
|
||||
return std::stoi(defines.at(name));
|
||||
}
|
||||
std::unordered_map<std::string, std::string> defines;
|
||||
int num_warps;
|
||||
};
|
||||
|
||||
/* ------------------------- */
|
||||
/* Runtime arguments */
|
||||
/* ------------------------- */
|
||||
|
||||
enum arg_type {
|
||||
INT1_T,
|
||||
INT8_T,
|
||||
INT16_T,
|
||||
INT32_T,
|
||||
INT64_T,
|
||||
HALF_T,
|
||||
FLOAT_T,
|
||||
DOUBLE_T,
|
||||
BUFFER_T
|
||||
};
|
||||
|
||||
inline size_t size_of(arg_type ty){
|
||||
switch(ty){
|
||||
case INT1_T : return 1;
|
||||
case INT8_T : return 1;
|
||||
case INT16_T : return 2;
|
||||
case INT32_T : return 4;
|
||||
case INT64_T : return 8;
|
||||
case HALF_T : return 2;
|
||||
case FLOAT_T : return 4;
|
||||
case DOUBLE_T: return 8;
|
||||
case BUFFER_T: return 8;
|
||||
default: throw std::runtime_error("unknown type");
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void add_arg(std::stringstream& ss, T arg) {
|
||||
ss.write((char*)&arg, sizeof(T));
|
||||
}
|
||||
|
||||
|
||||
/* ------------------------- */
|
||||
/* ------------------------- */
|
||||
|
||||
class kernel{
|
||||
public:
|
||||
typedef std::vector<size_t> grid_t;
|
||||
|
||||
public:
|
||||
static std::shared_ptr<ir::module> src_to_ir(const std::string& src, const options_t& opt);
|
||||
static std::tuple<std::shared_ptr<driver::module>,
|
||||
std::shared_ptr<driver::kernel>,
|
||||
size_t> ir_to_bin(ir::module& ir, driver::device *dev, const options_t &opt);
|
||||
|
||||
public:
|
||||
kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map<int, triton::ir::attribute> &attrs = {});
|
||||
void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const;
|
||||
std::string get_asm(const std::string &mode);
|
||||
|
||||
public:
|
||||
const options_t opt;
|
||||
|
||||
private:
|
||||
driver::device* dev_;
|
||||
// handles
|
||||
std::shared_ptr<ir::module> ir_;
|
||||
std::shared_ptr<driver::module> mod_;
|
||||
std::shared_ptr<driver::kernel> ker_;
|
||||
// shared mem
|
||||
size_t shared_mem_;
|
||||
};
|
||||
|
||||
struct config {
|
||||
std::map<std::string, std::string> defines;
|
||||
int num_warps;
|
||||
};
|
||||
|
||||
class function {
|
||||
public:
|
||||
typedef std::function<kernel::grid_t(const options_t&)> grid_fn_ty;
|
||||
typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
|
||||
typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
|
||||
typedef std::vector<config> autotune_confs_t;
|
||||
|
||||
public:
|
||||
function(const std::string& src, const options_t& opt, driver::device *device,
|
||||
const std::vector<config>& tune_confs = {}, const std::vector<std::string> &tune_key = {});
|
||||
kernel* autotune(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
|
||||
void operator()(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
|
||||
const std::vector<arg_type> get_signature() { return sig_; }
|
||||
|
||||
private:
|
||||
std::map<std::vector<uint64_t>, std::vector<std::shared_ptr<kernel>>> kernels_;
|
||||
std::map<std::vector<uint64_t>, kernel*> cache_;
|
||||
std::vector<arg_type> sig_;
|
||||
std::vector<int> align_idxs_;
|
||||
std::vector<int> int_idxs_;
|
||||
std::vector<int> key_idxs_;
|
||||
std::vector<int> arg_size_;
|
||||
std::vector<int> arg_off_;
|
||||
std::vector<options_t> opts_;
|
||||
std::string src_;
|
||||
driver::device* device_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
Reference in New Issue
Block a user