[codegen] more cleaning
This commit is contained in:
@@ -148,20 +148,41 @@ private:
|
||||
|
||||
|
||||
class generator: public ir::visitor {
|
||||
private:
|
||||
Type *type(ir::type *ty);
|
||||
|
||||
private:
|
||||
void visit_hmma_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
|
||||
void visit_scanline_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add);
|
||||
void visit_outer_dot(ir::dot_inst*, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
|
||||
void visit_outer_dot(ir::dot_inst*, distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK,
|
||||
Type *c_ty, Function *f_mul_add);
|
||||
|
||||
Type *type(ir::type *ty);
|
||||
void for_each(ir::value *x, const std::function<void(indices_t)>& fn);
|
||||
void get_value(ir::value *x, const indices_t& idx);
|
||||
Value* get_value(ir::value *x, const indices_t& idx);
|
||||
void set_value(ir::value *x, const indices_t& idx, Value* v);
|
||||
|
||||
public:
|
||||
|
||||
generator(LLVMContext *ctx,
|
||||
Function *fn,
|
||||
Builder *builder,
|
||||
std::map<ir::value *, Value *>& vmap,
|
||||
std::map<ir::value *, tile *>& tmap,
|
||||
target *tgt,
|
||||
analysis::layout *layouts,
|
||||
analysis::align *alignment,
|
||||
analysis::allocation *alloc,
|
||||
Value *sh_mem_ptr,
|
||||
Value *offset_a_i, Value *offset_a_k,
|
||||
Value *offset_b_j, Value *offset_b_k,
|
||||
unsigned num_packs_0, unsigned num_packs_1,
|
||||
unsigned pack_size_0, unsigned pack_size_1,
|
||||
unsigned num_warps)
|
||||
: ctx_(ctx), fn_(fn), builder_(builder), vmap_(vmap), tmap_(tmap), tgt_(tgt),
|
||||
layouts_(layouts), alignment_(alignment), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr),
|
||||
offset_a_i_(offset_a_i), offset_a_k_(offset_a_k), offset_b_j_(offset_b_j), offset_b_k_(offset_b_k),
|
||||
num_packs_0_(num_packs_0), num_packs_1_(num_packs_1), pack_size_0_(pack_size_0), pack_size_1_(pack_size_1),
|
||||
num_warps_(num_warps) { }
|
||||
|
||||
|
||||
void visit_phi_node(ir::phi_node*);
|
||||
void visit_binary_operator(ir::binary_operator*);
|
||||
void visit_getelementptr_inst(ir::getelementptr_inst*);
|
||||
@@ -180,7 +201,6 @@ public:
|
||||
void visit_unmasked_store_inst(ir::unmasked_store_inst*);
|
||||
void visit_masked_store_inst(ir::masked_store_inst*);
|
||||
|
||||
void visit_retile_inst(ir::retile_inst*);
|
||||
void visit_reshape_inst(ir::reshape_inst*);
|
||||
void visit_splat_inst(ir::splat_inst*);
|
||||
void visit_broadcast_inst(ir::broadcast_inst*);
|
||||
@@ -209,8 +229,8 @@ private:
|
||||
Function *fn_;
|
||||
Builder *builder_;
|
||||
|
||||
std::map<ir::value *, Value *> vmap_;
|
||||
std::map<ir::value *, tile *> tmap_;
|
||||
std::map<ir::value *, Value *>& vmap_;
|
||||
std::map<ir::value *, tile *>& tmap_;
|
||||
target *tgt_;
|
||||
analysis::layout *layouts_;
|
||||
analysis::align *alignment_;
|
||||
@@ -235,8 +255,6 @@ private:
|
||||
|
||||
// LLVM conversions
|
||||
Type* llvm_type(ir::type *ty, LLVMContext &ctx);
|
||||
Value* llvm_value(ir::value *v, Builder &builder);
|
||||
Instruction* llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, Builder &builder);
|
||||
Constant* llvm_constant(ir::constant *cst, LLVMContext &ctx);
|
||||
Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder);
|
||||
ArrayType* llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx);
|
||||
@@ -256,37 +274,7 @@ private:
|
||||
void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
|
||||
|
||||
// lower scalar instruction
|
||||
void lower_value(ir::value *src, Builder &builder, std::set<ir::value*>& seen);
|
||||
// lower tile instruction
|
||||
void lower_masked_store(ir::masked_store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_dynamic_program_idx(ir::make_range_dyn *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_reshape(ir::reshape_inst* x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_copy_from_shared(ir::copy_from_shared_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
// matrix multiply
|
||||
void lower_hmma_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder,
|
||||
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK);
|
||||
void lower_scanline_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder,
|
||||
distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK,
|
||||
Type *c_ty, Function *f_mul_add);
|
||||
void lower_outer_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder,
|
||||
distributed_tile *TC, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD,
|
||||
Type *c_ty, Function *f_mul_add);
|
||||
void lower_dot(ir::dot_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
// load
|
||||
void lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
// element-wise
|
||||
void lower_elementwise(ir::instruction *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_tile_instruction(ir::instruction *src, Builder &builder);
|
||||
|
||||
|
||||
void lower_value(ir::value *src, Builder &builder, generator* gen, std::set<ir::value*>& seen);
|
||||
|
||||
public:
|
||||
selection(analysis::liveness* liveness, analysis::allocation *alloc,
|
||||
|
@@ -71,6 +71,8 @@ public:
|
||||
}
|
||||
// instruction id
|
||||
value_id_t get_id() const { return id_; }
|
||||
// visit
|
||||
virtual void accept(visitor *v) = 0;
|
||||
|
||||
private:
|
||||
basic_block *parent_;
|
||||
|
@@ -3,9 +3,11 @@
|
||||
#ifndef _TRITON_IR_VISITOR_H_
|
||||
#define _TRITON_IR_VISITOR_H_
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class instruction;
|
||||
|
||||
class phi_node;
|
||||
class binary_operator;
|
||||
@@ -13,6 +15,7 @@ class getelementptr_inst;
|
||||
|
||||
class icmp_inst;
|
||||
class fcmp_inst;
|
||||
class cast_inst;
|
||||
class trunc_inst;
|
||||
class z_ext_inst;
|
||||
class s_ext_inst;
|
||||
@@ -73,7 +76,7 @@ public:
|
||||
|
||||
virtual void visit_icmp_inst(icmp_inst*) = 0;
|
||||
virtual void visit_fcmp_inst(fcmp_inst*) = 0;
|
||||
virtual void visit_cast_inst(trunc_inst*) = 0;
|
||||
virtual void visit_cast_inst(cast_inst*) = 0;
|
||||
|
||||
virtual void visit_return_inst(return_inst*) = 0;
|
||||
virtual void visit_cond_branch_inst(cond_branch_inst*) = 0;
|
||||
@@ -85,7 +88,6 @@ public:
|
||||
virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0;
|
||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||
|
||||
virtual void visit_retile_inst(retile_inst*) = 0;
|
||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||
virtual void visit_broadcast_inst(broadcast_inst*) = 0;
|
||||
|
File diff suppressed because it is too large
Load Diff
@@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
Reference in New Issue
Block a user