[codegen] more cleaning

This commit is contained in:
Philippe Tillet
2019-10-13 02:26:30 -04:00
parent cb12fc1a87
commit e787ce0cab
5 changed files with 125 additions and 826 deletions

View File

@@ -148,20 +148,41 @@ private:
class generator: public ir::visitor {
private:
Type *type(ir::type *ty);
private:
void visit_hmma_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
void visit_scanline_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
void visit_outer_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
void visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add);
Type *type(ir::type *ty);
void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
void get_value(ir::value *x, const indices_t& idx);
Value* get_value(ir::value *x, const indices_t& idx);
void set_value(ir::value *x, const indices_t& idx, Value* v);
public:
// Builds a generator by recording each argument in the member of the same
// name (trailing underscore). The constructor body is empty, so nothing else
// is done at construction time.
//
// vmap and tmap are taken by non-const reference.
// NOTE(review): the member declarations visible in this change appear to hold
// vmap_/tmap_ by reference as well; if so, both maps must outlive this
// object. Confirm against the final header.
generator(LLVMContext *ctx,
          Function *fn,
          Builder *builder,
          std::map<ir::value *, Value *>& vmap,
          std::map<ir::value *, tile *>& tmap,
          target *tgt,
          analysis::layout *layouts,
          analysis::align *alignment,
          analysis::allocation *alloc,
          Value *sh_mem_ptr,
          Value *offset_a_i,
          Value *offset_a_k,
          Value *offset_b_j,
          Value *offset_b_k,
          unsigned num_packs_0,
          unsigned num_packs_1,
          unsigned pack_size_0,
          unsigned pack_size_1,
          unsigned num_warps)
  : ctx_(ctx),
    fn_(fn),
    builder_(builder),
    vmap_(vmap),
    tmap_(tmap),
    tgt_(tgt),
    layouts_(layouts),
    alignment_(alignment),
    alloc_(alloc),
    sh_mem_ptr_(sh_mem_ptr),
    offset_a_i_(offset_a_i),
    offset_a_k_(offset_a_k),
    offset_b_j_(offset_b_j),
    offset_b_k_(offset_b_k),
    num_packs_0_(num_packs_0),
    num_packs_1_(num_packs_1),
    pack_size_0_(pack_size_0),
    pack_size_1_(pack_size_1),
    num_warps_(num_warps)
{ }
void visit_phi_node(ir::phi_node*);
void visit_binary_operator(ir::binary_operator*);
void visit_getelementptr_inst(ir::getelementptr_inst*);
@@ -180,7 +201,6 @@ public:
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
void visit_masked_store_inst(ir::masked_store_inst*);
void visit_retile_inst(ir::retile_inst*);
void visit_reshape_inst(ir::reshape_inst*);
void visit_splat_inst(ir::splat_inst*);
void visit_broadcast_inst(ir::broadcast_inst*);
@@ -209,8 +229,8 @@ private:
Function *fn_;
Builder *builder_;
std::map<ir::value *, Value *> vmap_;
std::map<ir::value *, tile *> tmap_;
std::map<ir::value *, Value *>& vmap_;
std::map<ir::value *, tile *>& tmap_;
target *tgt_;
analysis::layout *layouts_;
analysis::align *alignment_;
@@ -235,8 +255,6 @@ private:
// LLVM conversions
Type* llvm_type(ir::type *ty, LLVMContext &ctx);
Value* llvm_value(ir::value *v, Builder &builder);
Instruction* llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, Builder &builder);
Constant* llvm_constant(ir::constant *cst, LLVMContext &ctx);
Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder);
ArrayType* llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx);
@@ -256,37 +274,7 @@ private:
void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
// lower scalar instruction
void lower_value(ir::value *src, Builder &builder, std::set<ir::value*>& seen);
// lower tile instruction
void lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_dynamic_program_idx(ir::make_range_dyn *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_reshape(ir::reshape_inst* x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_copy_from_shared(ir::copy_from_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
// matrix multiply
void lower_hmma_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder,
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
void lower_scanline_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder,
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
Type *c_ty, Function *f_mul_add);
void lower_outer_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder,
distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD,
Type *c_ty, Function *f_mul_add);
void lower_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
// load
void lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
// element-wise
void lower_elementwise(ir::instruction *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_tile_instruction(ir::instruction *src, Builder &builder);
void lower_value(ir::value *src, Builder &builder, generator* gen, std::set<ir::value*>& seen);
public:
selection(analysis::liveness* liveness, analysis::allocation *alloc,

View File

@@ -71,6 +71,8 @@ public:
}
// Accessor for this instruction's id: returns the stored id_ by value.
value_id_t get_id() const {
  return id_;
}
// visit
virtual void accept(visitor *v) = 0;
private:
basic_block *parent_;

View File

@@ -3,9 +3,11 @@
#ifndef _TRITON_IR_VISITOR_H_
#define _TRITON_IR_VISITOR_H_
namespace triton{
namespace ir{
class instruction;
class phi_node;
class binary_operator;
@@ -13,6 +15,7 @@ class getelementptr_inst;
class icmp_inst;
class fcmp_inst;
class cast_inst;
class trunc_inst;
class z_ext_inst;
class s_ext_inst;
@@ -73,7 +76,7 @@ public:
virtual void visit_icmp_inst(icmp_inst*) = 0;
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
virtual void visit_cast_inst(trunc_inst*) = 0;
virtual void visit_cast_inst(cast_inst*) = 0;
virtual void visit_return_inst(return_inst*) = 0;
virtual void visit_cond_branch_inst(cond_branch_inst*) = 0;
@@ -85,7 +88,6 @@ public:
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
virtual void visit_retile_inst(retile_inst*) = 0;
virtual void visit_reshape_inst(reshape_inst*) = 0;
virtual void visit_splat_inst(splat_inst*) = 0;
virtual void visit_broadcast_inst(broadcast_inst*) = 0;

File diff suppressed because it is too large Load Diff

View File

@@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
// Constructs a cu_module from an LLVM module. Ownership of ll_module is moved
// into compile_llvm_module(), which is called with the context's device; the
// string it returns is forwarded, with the same context, to another cu_module
// constructor (delegating construction). This constructor has no body of its
// own.
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module)
  : cu_module(context,
              compile_llvm_module(std::move(ll_module), context->device()))
{ }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
cu_context::context_switcher ctx(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};