diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index 6ec396a1a..efebc102e 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -48,14 +48,14 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int stream->synchronize(); triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8); // benchmark triton - double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream); + double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream); // benchmark cublas NumericT alpha = 1; NumericT beta = 0; int32_t lda = AT ? K : M; int32_t ldb = BT ? N : K; int32_t ldc = M; - cublasGemmAlgo_t fastest; +// cublasGemmAlgo_t fastest; // cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K, // &alpha, da, lda, // db, ldb, &beta, @@ -109,6 +109,6 @@ int main() { // does the work for(config_t c: configs){ perf_t perf = c.perf(stream); - std::cout << c.repr() << ", " << perf.triton << ", " << perf.cublas << std::endl; + std::cout << "// " << c.repr() << ", " << perf.triton << ", " << perf.cublas << std::endl; } } diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index fc10d4316..91ed2daaa 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -144,6 +144,6 @@ int main() { for(config_t c: configs){ std::string repr = c.repr(); perf_t perf = c.perf(stream); - std::cout << repr << ", " << perf.triton << ", " << perf.cublas << std::endl; + std::cout << "// " << repr << ", " << perf.triton << ", " << perf.cublas << std::endl; } } diff --git a/include/triton/codegen/optimize_cse.h b/include/triton/codegen/optimize_dce.h similarity index 87% rename from include/triton/codegen/optimize_cse.h rename to include/triton/codegen/optimize_dce.h index d718f318e..e40bafef5 100644 --- a/include/triton/codegen/optimize_cse.h +++ b/include/triton/codegen/optimize_dce.h @@ -14,9 +14,9 @@ namespace ir { namespace codegen{ class tune; -class 
optimize_cse { +class optimize_dce { public: - optimize_cse() {} + optimize_dce() {} void run(ir::module &mod); }; diff --git a/include/triton/codegen/optimize_trans.h b/include/triton/codegen/optimize_trans.h index beaace2a5..c6ec73b4d 100644 --- a/include/triton/codegen/optimize_trans.h +++ b/include/triton/codegen/optimize_trans.h @@ -19,7 +19,7 @@ namespace codegen{ class optimize_trans { private: - ir::value *replace_phi(ir::value* value, std::vector& to_delete, ir::builder &builder); + ir::value *replace_phi(ir::value* value, ir::builder &builder); public: optimize_trans() {} diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index 4355bfce6..e1d2dbf0b 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -104,19 +104,10 @@ private: }; -// Fragmented tile -class fragmented_tile: public tile{ -public: - -private: - -}; - // Selection pass class selection{ typedef std::map vmap_t; typedef std::map tmap_t; - typedef std::map, llvm::BasicBlock*> pmap_t; private: // utils @@ -152,8 +143,6 @@ public: private: vmap_t vmap_; tmap_t tmap_; - pmap_t pmap_; - pmap_t last_block_; shmem_allocation *alloc_; tune *params_; target *tgt_; diff --git a/include/triton/dnn/heuristics.h b/include/triton/dnn/heuristics.h index bd9bc50aa..e2efe6df2 100644 --- a/include/triton/dnn/heuristics.h +++ b/include/triton/dnn/heuristics.h @@ -101,7 +101,7 @@ inline std::vector dot_search_space(bool AT, bool BT) { inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) { size_t TM = 128; size_t TN = 128; -// return {4, 8, 256, 8, 8, 64, 2, 2, 2, 2, 32, 32, 16, 1}; +// return {4, 4, 128, 8, 4, 128, 2, 2, 2, 2, 32, 32, 16, 1}; return params.at(trans_key_t{AT, BT}).at(size_key_t{TM, TN}); } diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 598c82454..9cee12c68 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -50,15 +50,14 @@ public: 
block_->get_inst_list().insert(insert_point_, inst); inst->set_parent(block_); inst->set_name(name); +// for(ir::value* op: inst->ops()) +// op->add_use(inst); return inst; } // terminator instructions value* create_br(basic_block *dest); value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest); value* create_ret_void(); - // Tile-level control flow -// value *create_mask(value *pred, const std::string &name = ""); -// value *create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name = ""); // Cast instructions value *create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name = ""); value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = ""); @@ -120,6 +119,8 @@ public: // Input/Output value *create_load(value *arg, const std::string &name = ""); value *create_store(value *ptr, value *val, const std::string &name = ""); + value *create_masked_load(value *arg, value *mask, value *false_value, const std::string &name = ""); + value *create_masked_store(value *ptr, value *val, value *mask, const std::string &name = ""); // Tile instruction value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 29b2678a3..d76ebf719 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -21,11 +21,6 @@ class context; class result_reference; class instruction: public user{ public: -// struct mask_info_t { -// value *pred; -// value *else_value; -// }; - virtual std::string repr_impl() const = 0; protected: @@ -38,11 +33,6 @@ public: const basic_block *get_parent() const { return parent_; } basic_block *get_parent() { return parent_; } void erase_from_parent(); -// // mask -// void set_mask_pred(value *pred) { 
resize_hidden(1); set_operand(get_num_operands(), pred); } -// value* get_mask_pred() const { if(get_num_hidden() == 0) return nullptr; return get_operand(get_num_operands()); } -// void set_mask_else(value *x) { resize_hidden(2); set_operand(get_num_operands() + 1, x); } -// value* get_mask_else() const { if(get_num_hidden() < 2) return nullptr; return get_operand(get_num_operands() + 1); } // helpers bool has_tile_result_or_op(); // repr @@ -56,8 +46,6 @@ public: unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];} private: basic_block *parent_; -// value *pred_; -// value *mask_pred_; std::vector results_; std::map metadatas_; }; @@ -336,35 +324,6 @@ public: const std::string &name = "", instruction *next = nullptr); }; -//// mask -//class mask_inst: public instruction { -//private: -// std::string repr_impl() const { return "mask"; } -// mask_inst(ir::value *pred, const std::string &name, instruction *next); - -//public: -// static mask_inst* create(ir::value *pred, const std::string &name = "", instruction *next = nullptr); -//}; - -//// merge -//class psi_inst: public instruction { -//private: -// std::string repr_impl() const { return "merge"; } -// psi_inst(ir::value *mask_true, ir::value *value_true, -// ir::value *mask_false, ir::value *value_false, -// const std::string &name, instruction *next); - -//public: -// static psi_inst* create(ir::value *mask_true, ir::value *value_true, -// ir::value *mask_false, ir::value *value_false, -// const std::string &name = "", instruction *next = nullptr); -// ir::value *get_mask_true() { return get_operand(0); } -// ir::value *get_value_true() { return get_operand(1); } -// ir::value *get_mask_false() { return get_operand(2); } -// ir::value *get_value_false() { return get_operand(3); } - -//}; - //===----------------------------------------------------------------------===// // getelementptr_inst classes //===----------------------------------------------------------------------===// @@ 
-399,43 +358,78 @@ private: // load_inst/store_inst classes //===----------------------------------------------------------------------===// -class load_inst: public unary_inst{ -private: - std::string repr_impl() const { return "load"; } - load_inst(value *ptr, const std::string &name, instruction *next); +class io_inst: public instruction { +protected: + io_inst(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr); +public: +// value *get_mask() const; +// value *get_false_value() const; +}; + +class load_inst: public io_inst{ +protected: + load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next); private: + std::string repr_impl() const { return "load"; } static type *get_pointee_type(type *ty); public: // accessors value *get_pointer_operand() { return get_operand(0); } - value *get_mask() const; - value *set_mask(value *mask); // factory method - static load_inst* create(value *ptr, const std::string &name = "", + static load_inst* create(value *ptr, + const std::string &name = "", instruction *next = nullptr); - -private: - value *mask_; }; -class store_inst: public instruction{ +class masked_load_inst: public load_inst{ private: - std::string repr_impl() const { return "store"; } - store_inst(value *ptr, value *v, const std::string &name, instruction *next); + std::string repr_impl() const { return "masked_load"; } + masked_load_inst(value *ptr, value *mask, value *false_value, + const std::string &name, instruction *next); public: - value *get_pointer_operand() { return get_operand(0); } - value *get_value_operand() { return get_operand(1); } - value *get_mask() const; - value *set_mask(value *mask); + // accessors + value *get_mask_operand() { return get_operand(1); } + value *get_false_value_operand() { return get_operand(2); } // factory method - static store_inst* create(value* ptr, value *v, const std::string &name = "", - instruction *next = nullptr); + static 
masked_load_inst* create(value *ptr, value *mask, value *false_value, + const std::string &name = "", + instruction *next = nullptr); +}; + +class store_inst: public io_inst{ +protected: + store_inst(value *ptr, value *v, unsigned num_extra_ops, + const std::string &name, instruction *next); private: - ir::value *mask_; + std::string repr_impl() const { return "store"; } + +public: + // accessors + value *get_pointer_operand() { return get_operand(0); } + value *get_value_operand() { return get_operand(1); } + // factory method + static store_inst* create(value* ptr, value *v, + const std::string &name = "", + instruction *next = nullptr); +}; + +class masked_store_inst: public store_inst{ +private: + std::string repr_impl() const { return "masked_store"; } + masked_store_inst(value *ptr, value *v, value *mask, + const std::string &name, instruction *next); + +public: + // accessors + value *get_mask_operand() { return get_operand(2); } + // factory method + static masked_store_inst* create(value *ptr, value *v, value *mask, + const std::string &name = "", + instruction *next = nullptr); }; //===----------------------------------------------------------------------===// @@ -507,21 +501,6 @@ protected: using instruction::instruction; }; -class get_global_range_inst: public builtin_inst { -private: - get_global_range_inst(type *ty, unsigned axis, const std::string &name, instruction *next); - std::string repr_impl() const { return "get_global_range(" + std::to_string(axis_) + ")"; } - -public: - static instruction* create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size, - const std::string &name = "", - instruction *next = nullptr); - unsigned get_axis() const { return axis_; } - -private: - unsigned axis_; -}; - class get_range_id_inst: public builtin_inst { private: get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next); diff --git a/include/triton/lang/expression.h b/include/triton/lang/expression.h index 
3d894c802..538485366 100644 --- a/include/triton/lang/expression.h +++ b/include/triton/lang/expression.h @@ -71,16 +71,6 @@ private: const constant* size_; }; -class get_global_range_expression: public builtin_expression{ -public: - get_global_range_expression(node *size, node *axis): size_((constant*)size), axis_((constant*)axis) { } - ir::value* codegen(ir::module *) const; - -private: - const constant* size_; - const constant* axis_; -}; - class get_range_id_expression: public builtin_expression{ public: get_range_id_expression(node *axis): axis_((constant*)axis) { } diff --git a/include/triton/lang/parser.y b/include/triton/lang/parser.y index 32b3c5ed4..645b0b51f 100644 --- a/include/triton/lang/parser.y +++ b/include/triton/lang/parser.y @@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64 %token IF ELSE FOR CONTINUE WHILE %token NEWAXIS ELLIPSIS AT -%token GET_GLOBAL_RANGE GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST +%token GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ATOMIC_ADD ALLOC_CONST %start translation_unit %% @@ -120,8 +120,7 @@ identifier /* Built-in */ builtin_expression - : GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range_expression($3, $6); } - | GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); } + : GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); } | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } | SQRT '(' expression ')' { $$ = new sqrt_expression($3); } | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); } diff --git a/include/triton/lang/scanner.l b/include/triton/lang/scanner.l index 0fbaa52d2..83d11035d 
100644 --- a/include/triton/lang/scanner.l +++ b/include/triton/lang/scanner.l @@ -44,7 +44,6 @@ using triton::lang::return_void; "fp32" { return return_impl(FP32, yytext); } "fp64" { return return_impl(FP64, yytext); } "..." { return return_impl(ELLIPSIS, yytext); } -"get_global_range" { return return_impl(GET_GLOBAL_RANGE, yytext); } "get_range_id" { return return_impl(GET_RANGE_ID, yytext); } "__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); } "__atomic_exchg" { return return_impl(ATOMIC_EXCHG, yytext); } diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index aa7a930bb..8f0f1ef73 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -11,7 +11,7 @@ #include "triton/codegen/selection.h" #include "triton/codegen/tune.h" #include "triton/codegen/optimize_dot.h" -#include "triton/codegen/optimize_cse.h" +#include "triton/codegen/optimize_dce.h" #include "triton/codegen/optimize_trans.h" #include "triton/codegen/shmem_allocation.h" #include "triton/codegen/shmem_liveness.h" @@ -63,7 +63,7 @@ public: vectorize(&tune), selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target), optimize_dot(&tune), - optimize_cse(), + optimize_dce(), optimize_trans(), alignment_info(), reassociate(&tune, &alignment_info), @@ -72,14 +72,11 @@ public: void target_independent(ir::module &module) { optimize_dot.run(module); optimize_trans.run(module); -// ir::print(module, std::cout); } void target_dependent(ir::module &module) { alignment_info.run(module); reassociate.run(module); - ir::print(module, std::cout); -// exit(EXIT_FAILURE); if(target_->is_gpu()){ shmem_info.run(module); shmem_liveness.run(module); @@ -87,6 +84,8 @@ public: shmem_barriers.run(module); } vectorize.run(module); + optimize_dce.run(module); +// ir::print(module, std::cout); } codegen::tune tune; @@ -97,7 +96,7 @@ public: codegen::vectorize vectorize; codegen::selection selection; codegen::optimize_dot optimize_dot; - codegen::optimize_cse 
optimize_cse; + codegen::optimize_dce optimize_dce; codegen::optimize_trans optimize_trans; codegen::alignment_info alignment_info; codegen::reassociate reassociate; diff --git a/lib/codegen/alignment_info.cpp b/lib/codegen/alignment_info.cpp index 7c40229a2..87df925df 100644 --- a/lib/codegen/alignment_info.cpp +++ b/lib/codegen/alignment_info.cpp @@ -109,8 +109,6 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){ if(!v->get_type()->is_tile_ty()) return cache(1); auto shapes = v->get_type()->get_tile_shapes(); - if(dynamic_cast(v)) - return cache(shapes[0]->get_value()); if(dynamic_cast(v)) return cache(shapes[0]->get_value()); if(auto *x = dynamic_cast(v)){ @@ -243,14 +241,6 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){ int op = populate_starting_multiple(x->get_operand(0)); return cache(op); } - if(auto *x = dynamic_cast(v)){ - return cache(v->get_type()->get_tile_shapes()[0]->get_value()); - } -// if(auto *x = dynamic_cast(v)){ -// int value_true = populate_starting_multiple(x->get_value_true()); -// int value_false = populate_starting_multiple(x->get_value_false()); -// return cache(gcd(value_true, value_false)); -// } if(auto *x = dynamic_cast(v)){ // put a conservative initial value in phi node to avoid infinite recursion unsigned result = 1; @@ -313,7 +303,6 @@ void alignment_info::run(ir::module &mod) { for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i: block->get_inst_list()){ populate_max_contiguous(i); - std::cout << i->get_name() << " " << is_constant_.at(i).num_cst << " " << max_contiguous_.at(i) << " " << starting_multiple_.at(i) << std::endl; } } diff --git a/lib/codegen/optimize_cse.cpp b/lib/codegen/optimize_cse.cpp deleted file mode 100644 index b0c07a99e..000000000 --- a/lib/codegen/optimize_cse.cpp +++ /dev/null @@ -1,14 +0,0 @@ -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/module.h" -#include "triton/codegen/optimize_cse.h" - -namespace triton { 
-namespace codegen{ - - -void optimize_cse::run(ir::module &mod) { -} - -} -} diff --git a/lib/codegen/optimize_dce.cpp b/lib/codegen/optimize_dce.cpp new file mode 100644 index 000000000..d30bf4c1d --- /dev/null +++ b/lib/codegen/optimize_dce.cpp @@ -0,0 +1,60 @@ +#include "triton/ir/function.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/module.h" +#include "triton/ir/cfg.h" +#include "triton/codegen/optimize_dce.h" + +namespace triton { +namespace codegen{ + + +void optimize_dce::run(ir::module &mod) { + std::list work_list; + std::set marked; + + // initialize work-list + for(ir::function *fn: mod.get_function_list()){ + std::vector rpo = ir::cfg::reverse_post_order(fn); + // iterate through blocks + for(ir::basic_block *block: rpo) + for(ir::instruction *i: block->get_inst_list()){ + if(dynamic_cast(i) || dynamic_cast(i) || dynamic_cast(i) + || dynamic_cast(i) || dynamic_cast(i)){ + work_list.push_back(i); + marked.insert(i); + } + } + } + + // mark -- ignore branches + while(!work_list.empty()){ + ir::instruction* current = work_list.back(); + work_list.pop_back(); + // mark instruction operands + for(ir::value* op: current->ops()) { + if(auto *i = dynamic_cast(op)) + if(marked.insert(i).second) + work_list.push_back(i); + } + // TODO: mark last instruction of current's reverse-dominance frontier + } + + // sweep -- delete non-branch unmarked instructions + std::vector to_delete; + for(ir::function *fn: mod.get_function_list()){ + std::vector rpo = ir::cfg::reverse_post_order(fn); + // iterate through blocks + for(ir::basic_block *block: rpo) + for(ir::instruction *i: block->get_inst_list()){ + if(marked.find(i) == marked.end()) + to_delete.push_back(i); + } + } + + // delete + for(ir::instruction* i: to_delete) + i->erase_from_parent(); +} + +} +} diff --git a/lib/codegen/optimize_trans.cpp b/lib/codegen/optimize_trans.cpp index b6ad7cfd2..0fb96ac96 100644 --- a/lib/codegen/optimize_trans.cpp +++ b/lib/codegen/optimize_trans.cpp @@ -7,20 +7,18 
@@ namespace codegen{ ir::value* optimize_trans::replace_phi(ir::value* value, - std::vector& to_delete, ir::builder& builder){ if(auto phi = dynamic_cast(value)) { // transpose operands std::vector incs; for(unsigned n = 0; n < phi->get_num_incoming(); n++) - incs.push_back(replace_phi(phi->get_incoming_value(n), to_delete, builder)); + incs.push_back(replace_phi(phi->get_incoming_value(n), builder)); // create phi for transposed values builder.set_insert_point(phi); ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name()); for(unsigned n = 0; n < phi->get_num_incoming(); n++) result->add_incoming(incs[n], phi->get_incoming_block(n)); phi->replace_all_uses_with(result); - to_delete.push_back(phi); return result; } else if(auto i = dynamic_cast(value)){ @@ -39,7 +37,6 @@ ir::value* optimize_trans::replace_phi(ir::value* value, void optimize_trans::run(ir::module &mod) { ir::builder &builder = mod.get_builder(); - std::vector to_delete; // iterate for(ir::function *fn: mod.get_function_list()) for(ir::basic_block *block: fn->blocks()) @@ -56,15 +53,11 @@ void optimize_trans::run(ir::module &mod) { // trans(phi) -> phi(trans(), trans()...) 
if(dynamic_cast(op)){ - ir::value* new_phi = replace_phi(op, to_delete, builder); - to_delete.push_back(trans); + ir::value* new_phi = replace_phi(op, builder); trans->replace_all_uses_with(new_phi); } } } - // erase dead code - for(ir::instruction* i: to_delete) - i->erase_from_parent(); } } diff --git a/lib/codegen/reassociate.cpp b/lib/codegen/reassociate.cpp index fa7c256fd..bf36b2033 100644 --- a/lib/codegen/reassociate.cpp +++ b/lib/codegen/reassociate.cpp @@ -189,8 +189,6 @@ void reassociate::run(ir::module &mod) { // reassociate std::map infos; - std::map> re_ordered; - for(ir::function *fn: mod.get_function_list()){ std::vector rpo = ir::cfg::reverse_post_order(fn); // iterate through blocks @@ -259,11 +257,6 @@ void reassociate::run(ir::module &mod) { params_->copy(new_pz, pz); align_->copy(new_pz, pz); } - -// // reassociate pointer -// reassociate_ptr(pz, builder, offsets); - - } } } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index caf666bfd..b4e40a3f2 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -32,7 +32,7 @@ void distributed_tile::init_indices() { current.push_back(axes_[d].values[id[d]]); size_t sz = indices_.size(); indices_[current] = sz; - values_[current] = UndefValue::get(ty_); + values_[current] = nullptr; ordered_indices_.push_back(current); id[0]++; while(id[k] == axes_[k].values.size()){ @@ -57,12 +57,17 @@ distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_ init_indices(); } -void distributed_tile::set_value(indices_t idx, Value *v) { - values_[idx] = v; +void distributed_tile::set_value(indices_t idx, Value *x) { + assert(x->getType() == ty_ && "cannot set a value of different type"); + Value *&result = values_[idx]; + assert(!result && "value cannot be set twice"); + result = x; } Value* distributed_tile::get_value(indices_t idx) { - return values_[idx]; + Value *result = values_.at(idx); + assert(result && "value has not been set"); + return result; } 
unsigned distributed_tile::get_linear_index(indices_t idx) { @@ -688,15 +693,15 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, } bool vectorize = dynamic_cast(v); distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize); - tmap_.insert({v, T}); + bool is_inserted = tmap_.insert({v, T}).second; // constant range - if(dynamic_cast(v)){ + if(is_inserted && dynamic_cast(v)){ T->for_each([&](indices_t idx){ assert(idx.size() == 1); T->set_value(idx, idx[0]); }); } - if(dynamic_cast(v)){ + if(is_inserted && dynamic_cast(v)){ T->for_each([&](indices_t idx){ assert(idx.size() == 1); BinaryOperator *bin_add = dyn_cast(idx[0]); @@ -746,31 +751,41 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & LLVMContext &ctx = builder.getContext(); Function *fn = block->getParent(); // store - if(auto *x = dynamic_cast(ins)) { - distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand()); - tile *value = tmap_.at(x->get_value_operand()); - ir::value *mask = x->get_mask(); - if(mask) { - distributed_tile* preds = (distributed_tile*)tmap_.at(mask); - ptr->for_each([&](indices_t idx){ - BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn); - BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn); - builder.CreateCondBr(preds->get_value(idx), mask_then_bb, mask_done_bb); - builder.SetInsertPoint(mask_then_bb); - builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); - builder.CreateBr(mask_done_bb); - builder.SetInsertPoint(mask_done_bb); - }); - } - else { - ptr->for_each([&](indices_t idx){ - if(GetElementPtrInst *gep = dyn_cast(ptr->get_value(idx))) - if(BinaryOperator *binop = dyn_cast(*gep->idx_begin())){ - std::cout << isa(binop->getOperand(0)) << " " << isa(binop->getOperand(1)) << std::endl; - } - builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); - }); - } + if(auto *x = dynamic_cast(ins)){ + distributed_tile* ptrs = 
(distributed_tile*)tmap_.at(x->get_pointer_operand()); + tile *scalars = tmap_.at(x->get_value_operand()); + ir::value *mask = x->get_mask_operand(); + distributed_tile* preds = (distributed_tile*)tmap_.at(mask); + ptrs->for_each([&](indices_t idx){ + Value *scalar = scalars->get_value(idx); + Value *ptr = ptrs->get_value(idx); + Value *pred = preds->get_value(idx); +// std::string offset = ""; +// if(GetElementPtrInst *gep = dyn_cast(ptr)) +// if(gep->getNumIndices() == 1) +// if(ConstantInt *cst = dyn_cast(gep->idx_begin())){ +// offset = " + " + std::to_string(cst->getValue().getSExtValue()*4); +// } +// FunctionType *ty = FunctionType::get(Type::getVoidTy(ctx), {pred->getType(), ptr->getType(), scalar->getType()}, false); +// std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;"; +// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true); +// builder.CreateCall(iasm, {pred, ptr, scalar}); + + BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn); + BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn); + builder.CreateCondBr(pred, mask_then_bb, mask_done_bb); + builder.SetInsertPoint(mask_then_bb); + builder.CreateStore(scalar, ptr); + builder.CreateBr(mask_done_bb); + builder.SetInsertPoint(mask_done_bb); + }); + } + else if(auto *x = dynamic_cast(ins)) { + distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand()); + tile *scalars = tmap_.at(x->get_value_operand()); + ptrs->for_each([&](indices_t idx){ + builder.CreateStore(scalars->get_value(idx), ptrs->get_value(idx)); + }); } else { if(auto *x = dynamic_cast(ins)){ @@ -837,14 +852,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & if(!ins->get_type()->is_tile_ty()) return; const auto& shapes = ins->get_type()->get_tile_shapes(); - // global_range - if(auto *x = dynamic_cast(ins)) { - Value *offset = tgt_->get_global_offset(module, builder, shapes[0]->get_value(), x->get_axis()); - 
result->for_each([&](indices_t idx){ - BinaryOperator *bin = static_cast(idx[0]); - result->set_value(idx, builder.CreateAdd(bin, offset)); - }); - } // nv_dynamic_range_idx_inst if(dynamic_cast(ins)){ result->for_each([&](indices_t idx){ @@ -855,49 +862,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & result->set_value(idx, res); }); } -// // mask -// else if(dynamic_cast(ins)) { -// distributed_tile* pred = (distributed_tile*)tmap_.at(ins->get_operand(0)); -// distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(ins->get_result(0)); -// distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(ins->get_result(1)); -// pred->for_each([&](indices_t idx){ -// BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn); -// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else", fn); -// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn); -// builder.CreateCondBr(pred->get_value(idx), mask_then_bb, mask_else_bb); -// builder.SetInsertPoint(mask_then_bb); -// builder.CreateBr(mask_done_bb); -// builder.SetInsertPoint(mask_else_bb); -// builder.CreateBr(mask_done_bb); -// builder.SetInsertPoint(mask_done_bb); -// pmap_.insert({{mask_tile_true, idx}, mask_then_bb}); -// pmap_.insert({{mask_tile_false, idx}, mask_else_bb}); -// last_block_.insert({{mask_tile_true, idx}, mask_done_bb}); -// last_block_.insert({{mask_tile_false, idx}, mask_done_bb}); -// }); -// } -// // merge -// else if(auto *merge = dynamic_cast(ins)) { -// distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true()); -// distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true()); -// distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false()); -// distributed_tile *value_tile_false = (distributed_tile*)tmap_.at(merge->get_value_false()); -// result->for_each([&](indices_t idx){ -// BasicBlock *block_true = pmap_.at({mask_tile_true, 
idx}); -// Value *value_true = value_tile_true->get_value(idx); -// BasicBlock *block_false = pmap_.at({mask_tile_false, idx}); -// Value *value_false = value_tile_false->get_value(idx); -// BasicBlock *block_done = last_block_.at({mask_tile_true, idx}); -// if(block_done->getTerminator()) -// builder.SetInsertPoint(block_done->getTerminator()); -// else -// builder.SetInsertPoint(block_done); -// PHINode *phi = builder.CreatePHI(value_true->getType(), 2); -// phi->addIncoming(value_true, block_true); -// phi->addIncoming(value_false,block_false); -// result->set_value(idx, phi); -// }); -// } // reshape else if(dynamic_cast(ins)) { ir::value* in = ins->get_operand(0); @@ -939,9 +903,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & in->for_each([&](indices_t idx){ unsigned linear = in->get_linear_index(idx); unsigned id = linear / vector_size; + Value *in_value = in->get_value(idx); if(linear % vector_size == 0) - packets[id] = result->get_value(idx); - packets[id] = builder.CreateInsertElement(packets.at(id), in->get_value(idx), linear % vector_size); + packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size)); + packets[id] = builder.CreateInsertElement(packets.at(id), in_value, linear % vector_size); }); result->for_each([&](indices_t idx){ unsigned linear = in->get_linear_index(idx); @@ -1017,8 +982,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & TB->set_return_mode(true); std::vector fc; + result->for_each([&](indices_t idx){ - fc.push_back(result->get_value(idx)); + fc.push_back(TC->get_value(idx)); +// fc.push_back(UndefValue::get(TC->get_value(idx)->getType())); }); Type *fp32_ty = builder.getFloatTy(); @@ -1076,10 +1043,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & Value *hb = TB->get_value(idx_b); for(unsigned ii = 0; ii < pack_size_0_; ii++) for(unsigned jj = 0; jj < pack_size_1_; jj++){ - Value *ha0 = 
builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)); - Value *ha1 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)); - Value *hb0 = builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)); - Value *hb1 = builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)); + Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty); + Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty); + Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty); + Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty); std::vector idx = { (pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc, (pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc, @@ -1136,24 +1103,106 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & }); } } - else if(auto *ld = dynamic_cast(ins)){ + else if(auto *ld = dynamic_cast(ins)){ + // find vector size ir::value *ptr = ld->get_pointer_operand(); unsigned starting_multiple = axis_info_->get_starting_multiple(ptr); unsigned max_contiguous = axis_info_->get_max_contiguous(ptr); unsigned alignment = std::min(starting_multiple, max_contiguous); unsigned vector_size = std::min(result->axis(0).contiguous, alignment); + distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); + distributed_tile *masks = (distributed_tile*)tmap_.at(ld->get_mask_operand()); + distributed_tile *false_values = (distributed_tile*)tmap_.at(ld->get_false_value_operand()); std::map packets; - distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand()); result->for_each([&](indices_t idx){ unsigned linear = result->get_linear_index(idx); unsigned id = linear / vector_size; - 
if(linear % vector_size == 0){ - Value *ptr = TP->get_value(idx); - ptr= builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size), - ptr->getType()->getPointerAddressSpace())); + if(linear % vector_size == 0) { + Value *ptr = pointers->get_value(idx); + ConstantInt *cst = nullptr; + if(GetElementPtrInst *gep = dyn_cast(ptr)) + if(gep->getNumIndices() == 1){ + cst = dyn_cast(gep->idx_begin()); + } + + ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size), + ptr->getType()->getPointerAddressSpace())); + Value *mask = masks->get_value(idx); + BasicBlock *current_bb = builder.GetInsertBlock(); + BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn); + BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn); + builder.CreateCondBr(mask, mask_then_bb, mask_done_bb); + builder.SetInsertPoint(mask_then_bb); + Value *result_then = builder.CreateLoad(ptr); + builder.CreateBr(mask_done_bb); + builder.SetInsertPoint(mask_done_bb); + Value *result = nullptr; + if(false_values){ + result = builder.CreatePHI(result_then->getType(), 2); + ((PHINode*)result)->addIncoming(result_then, mask_then_bb); + Value *result_false = false_values->get_value(idx); + if(vector_size > 1) + result_false = builder.CreateVectorSplat(vector_size, result_false); + ((PHINode*)result)->addIncoming(result_false, current_bb); + } + else + result = result_then; + +// std::string offset = ""; +// if(cst) +// offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size); +// Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2); +// Type *fp16x2_pack4_ty = StructType::get(ctx, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); +// FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false); +// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];"; +// if(false_value) +// asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 
0, 0, 0};"; +// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true); +// Value *result = builder.CreateCall(iasm, {mask, ptr}); + + packets[id] = result; + } + }); + // extract result element + result->for_each([&](indices_t idx){ + unsigned linear = result->get_linear_index(idx); + unsigned id = linear / vector_size; +// Value *tmp = builder.CreateExtractValue(packets.at(id), {(linear % vector_size) / 2}); +// Value *res = builder.CreateExtractElement(tmp, (linear % vector_size) % 2); +// result->set_value(idx, res); + result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size)); + }); + } + else if(auto *ld = dynamic_cast(ins)){ + // find vector size + ir::value *ptr = ld->get_pointer_operand(); + unsigned starting_multiple = axis_info_->get_starting_multiple(ptr); + unsigned max_contiguous = axis_info_->get_max_contiguous(ptr); + unsigned alignment = std::min(starting_multiple, max_contiguous); + unsigned vector_size = std::min(result->axis(0).contiguous, alignment); + distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); + // vector loads + std::map packets; + result->for_each([&](indices_t idx){ + unsigned linear = result->get_linear_index(idx); + unsigned id = linear / vector_size; + if(linear % vector_size == 0) { + Value *ptr = pointers->get_value(idx); + ConstantInt *cst = nullptr; + if(GetElementPtrInst *gep = dyn_cast(ptr)) + if(gep->getNumIndices() == 1){ + cst = dyn_cast(gep->idx_begin()); + } + ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size), + ptr->getType()->getPointerAddressSpace())); packets[id] = builder.CreateLoad(ptr); } - result->set_value(idx, builder.CreateExtractElement(packets.at(id), linear % vector_size)); + }); + // extract result element + result->for_each([&](indices_t idx){ + unsigned linear = result->get_linear_index(idx); + unsigned id = linear / vector_size; +// result->set_value(idx, 
builder.CreateExtractElement(packets.at(id), linear % vector_size)); }); } // element-wise diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 288eb4204..1da6240dd 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -106,9 +106,9 @@ void tune::init_c_graph(ir::instruction *v) { for(unsigned k = 0; k < v->get_num_results(); k++){ ir::value *result = v->get_result(k); for(unsigned i = 0; i < shapes.size(); i ++){ - for(ir::value* op: v->ops()){ + std::vector ops = v->ops(); + for(ir::value* op: ops) add_constraint({result, i}, {op, i}); - } } } } diff --git a/lib/dnn/gemm.cpp b/lib/dnn/gemm.cpp index 23e62ae76..897a26402 100644 --- a/lib/dnn/gemm.cpp +++ b/lib/dnn/gemm.cpp @@ -123,14 +123,16 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, )" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; int1 checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(; int1 checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(; - )" + a_ty_ + R"( a[)" + AS + R"(] = *pa; - )" + b_ty_ + R"( b[)" + BS + R"(] = *pb; + )" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0; + )" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0; for(int32 k = K; k > 0; k = k - TK){ c = dot()" + usea + ", " + useb + R"(, c); pa = pa + TK)" + lda0 + R"(; pb = pb + TK)" + ldb0 + R"(; - a = *pa; - b = *pb; + int1 checka[)" + AS + R"(] = k > TK; + int1 checkb[)" + BS + R"(] = k > TK; + a = checka ? *pa : 0; + b = checkb ? *pb : 0; } int32 rxc[TM] = ridx * TM + (0 ... TM); int32 ryc[TN] = ridy * TN + (0 ... 
TN); @@ -138,11 +140,10 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, int1 checkc1[TN] = ryc < N; int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; - *pc = c; + @checkc *pc = c; } )"; - std::cout << res << std::endl; os << res; } diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp index 551a55a20..4ff863666 100755 --- a/lib/driver/module.cpp +++ b/lib/driver/module.cpp @@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) { cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ - std::cout << source << std::endl; +// std::cout << source << std::endl; cu_context::context_switcher ctx_switch(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp index bff68e083..e58fd9924 100644 --- a/lib/ir/builder.cpp +++ b/lib/ir/builder.cpp @@ -85,20 +85,6 @@ value *builder::create_ret_void() { return insert(return_inst::create(ctx_)); } - -//===----------------------------------------------------------------------===// -// tile-level control-flow instructions -//===----------------------------------------------------------------------===// - -//value *builder::create_mask(value *pred, const std::string &name){ -// return insert(mask_inst::create(pred, name)); -//} - -//value *builder::create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name) { -// return insert(psi_inst::create(mask_true, value_true, mask_false, value_false, name)); -//} - - //===----------------------------------------------------------------------===// // cast instructions 
//===----------------------------------------------------------------------===// @@ -264,14 +250,22 @@ DEFINE_FCMP_INSTR(ONE, llvm::FCmpInst::FCMP_ONE) // load/store instructions //===----------------------------------------------------------------------===// -value *builder::create_load(value *arg, const std::string &name){ - return insert(load_inst::create(arg, name)); +value *builder::create_load(value *ptr, const std::string &name){ + return insert(load_inst::create(ptr, name)); } value *builder::create_store(value *ptr, value *val, const std::string &name){ return insert(store_inst::create(ptr, val, name)); } +value *builder::create_masked_load(value *ptr, value *mask, value *false_value, const std::string &name){ + return insert(masked_load_inst::create(ptr, mask, false_value, name)); +} + +value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){ + return insert(masked_store_inst::create(ptr, val, mask, name)); +} + //===----------------------------------------------------------------------===// // tile instructions //===----------------------------------------------------------------------===// @@ -296,10 +290,6 @@ value *builder::create_downcast(value *arg, const std::string &name) { // built-in instructions //===----------------------------------------------------------------------===// -value *builder::create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name) { - return insert(get_global_range_inst::create(ctx_, axis, size, name)); -} - value *builder::create_get_range_id(unsigned axis, const std::string &name) { return insert(get_range_id_inst::create(ctx_, axis, name)); } diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index b7743c7d5..9537336fb 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -270,6 +270,7 @@ std::string cast_inst::repr_impl() const { } // TODO bool cast_inst::is_valid(op_t op, value *arg, type *ty) { + 
assert(arg->get_type()->is_tile_ty() == ty->is_tile_ty()); return true; } @@ -348,34 +349,6 @@ cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, v set_operand(2, cond); } -// mask_inst -//mask_inst::mask_inst(value *pred, const std::string &name, instruction *next) -// : instruction(pred->get_type(), 1, 2, name, next) { -// set_operand(0, pred); -//} - -//mask_inst* mask_inst::create(value *pred, const std::string &name, instruction *next) { -// return new mask_inst(pred, name, next); -//} - -//// merge_inst -//psi_inst::psi_inst(value *mask_true, value *value_true, -// value *mask_false, value *value_false, -// const std::string &name, instruction *next) -// : instruction(value_true->get_type(), 4, 1, name, next) { -// set_operand(0, mask_true); -// set_operand(1, value_true); -// set_operand(2, mask_false); -// set_operand(3, value_false); -//} - -//psi_inst* psi_inst::create(value *mask_true, value *value_true, -// value *mask_false, value *value_false, -// const std::string &name, instruction *next) { -// return new psi_inst(mask_true, value_true, mask_false, value_false, name, next); -//} - - //===----------------------------------------------------------------------===// // getelementptr_inst classes @@ -440,6 +413,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vectorget_scalar_ty(); type *pointee_ty = scalar_ty->get_pointer_element_ty(); @@ -448,43 +428,52 @@ type *load_inst::get_pointee_type(type *ty) { return pointee_ty; } -load_inst::load_inst(value *ptr, const std::string &name, instruction *next) - : unary_inst(get_pointee_type(ptr->get_type()), ptr, name, next), mask_(nullptr){ -} - -value *load_inst::get_mask() const { - return mask_; -} - -value *load_inst::set_mask(value *mask) { - mask_ = mask; - return this; +load_inst::load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next) + : io_inst(get_pointee_type(ptr->get_type()), 1 + num_extra_ops, 1, name, next) { 
+ set_operand(0, ptr); } load_inst* load_inst::create(value *ptr, const std::string &name, instruction *next) { - return new load_inst(ptr, name, next); + return new load_inst(ptr, 0, name, next); } +// masked load +masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value, + const std::string &name, instruction *next) + : load_inst(ptr, 2, name, next) { + set_operand(1, mask); + set_operand(2, false_value); +} + +masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false_value, + const std::string &name, instruction *next) { + return new masked_load_inst(ptr, mask, false_value, name, next); +} + + // store -store_inst::store_inst(value *ptr, value *v, const std::string &name, instruction *next) - : instruction(type::get_void_ty(ptr->get_type()->get_context()), 2, 1, name, next), mask_(nullptr) { +store_inst::store_inst(value *ptr, value *val, unsigned num_extra_ops, + const std::string &name, instruction *next) + : io_inst(type::get_void_ty(ptr->get_type()->get_context()), 2 + num_extra_ops, 1, name, next) { set_operand(0, ptr); - set_operand(1, v); + set_operand(1, val); } -value *store_inst::get_mask() const { - return mask_; +store_inst* store_inst::create(value *ptr, value *val, + const std::string &name, instruction *next) { + return new store_inst(ptr, val, 0, name, next); } -value *store_inst::set_mask(value *mask) { - mask_ = mask; - return this; +// masked store +masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask, + const std::string &name, instruction *next) + : store_inst(ptr, val, 1, name, next) { + set_operand(2, mask); } -store_inst* store_inst::create(value *ptr, value *v, const std::string &name, instruction *next) { - return new store_inst(ptr, v, name, next); +masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) { + return new masked_store_inst(ptr, val, mask, name, next); } - 
//===----------------------------------------------------------------------===// // retile_inst classes //===----------------------------------------------------------------------===// @@ -636,19 +625,6 @@ instruction* select_inst::create(value *pred, value *if_value, value *else_value // builtin instructions //===----------------------------------------------------------------------===// -// get_global_range -get_global_range_inst::get_global_range_inst(type *ty, unsigned axis, - const std::string &name, instruction *next) - : builtin_inst(ty, 0, 1, name, next), axis_(axis) { - -} - -instruction* get_global_range_inst::create(context &ctx, unsigned axis, type::tile_shapes_t::value_type size, - const std::string &name, instruction *next) { - type *int_ty = type::get_int32_ty(ctx); - type *tile_ty = tile_type::get(int_ty, {size}); - return new get_global_range_inst(tile_ty, axis, name, next); -} // get_range_id get_range_id_inst::get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next) diff --git a/lib/ir/print.cpp b/lib/ir/print.cpp index b1e43c0e2..4b7248bc6 100644 --- a/lib/ir/print.cpp +++ b/lib/ir/print.cpp @@ -35,12 +35,6 @@ void print(module &mod, std::ostream& os) { os << std::endl; for(ir::instruction *inst: block->get_inst_list()){ os << " "; - if(auto *x = dynamic_cast(inst)) - if(ir::value *mask = x->get_mask()) - os << "@" << get_name(mask, cnt++) << " "; - if(auto *x = dynamic_cast(inst)) - if(ir::value *mask = x->get_mask()) - os << "@" << get_name(mask, cnt++) << " "; unsigned num_results = inst->get_num_results(); for(unsigned i = 0; i < num_results; i++){ os << get_name(inst->get_result(i), cnt++); diff --git a/lib/lang/declaration.cpp b/lib/lang/declaration.cpp index 64f238171..dba439ce1 100644 --- a/lib/lang/declaration.cpp +++ b/lib/lang/declaration.cpp @@ -151,7 +151,7 @@ ir::value* initializer::codegen(ir::module * mod) const{ } else if(expr_){ value = expr_->codegen(mod); - value = 
explicit_cast(mod->get_builder(), value, ty); + value = explicit_cast(mod->get_builder(), value, ty->get_scalar_ty()); implicit_broadcast(mod, ty, value); } value->set_name(name); diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp index b21a3b4c7..6baa1f3b2 100644 --- a/lib/lang/expression.cpp +++ b/lib/lang/expression.cpp @@ -115,12 +115,6 @@ ir::value* alloc_const_expression::codegen(ir::module *mod) const { return res; } -// get_global_range -ir::value* get_global_range_expression::codegen(ir::module *mod) const { - ir::builder &builder = mod->get_builder(); - return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod)); -} - // get_range_id ir::value* get_range_id_expression::codegen(ir::module *mod) const { return mod->get_builder().create_get_range_id(axis_->value()); @@ -254,39 +248,24 @@ ir::value* cast_expression::codegen(ir::module *mod) const{ } /* Conditional expression */ -ir::value *conditional_expression::codegen(ir::module *mod) const{ +ir::value *conditional_expression::codegen(ir::module *mod) const { + ir::builder &builder = mod->get_builder(); + ir::value *mask = cond_->codegen(mod); + ir::value *true_value = true_value_->codegen(mod); + ir::value *false_value = false_value_->codegen(mod); + bool is_float, is_ptr, is_int, is_signed; + implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed); + implicit_broadcast(mod, mask, true_value); + implicit_broadcast(mod, mask, false_value); + if(ir::load_inst* load = dynamic_cast(true_value)){ + load->erase_from_parent(); + return builder.create_masked_load(load->get_pointer_operand(), mask, false_value); + } + if(ir::load_inst* load = dynamic_cast(false_value)){ + load->erase_from_parent(); + return builder.create_masked_load(load->get_pointer_operand(), mask, true_value); + } throw std::runtime_error("not implemented"); -// ir::builder &builder = mod->get_builder(); -// ir::basic_block::inst_list_t &instructions = 
builder.get_insert_block()->get_inst_list(); -// ir::value *pred = cond_->codegen(mod); -// ir::instruction *mask = (ir::instruction*)builder.create_mask(pred); -// /* true value */ -// ir::value *true_mask = mask->get_result(0); -// auto it_true_begin = instructions.end(); -// it_true_begin--; -// ir::value *true_value = true_value_->codegen(mod); -// implicit_broadcast(mod, pred, true_value); -// it_true_begin++; -// auto it_true_end = instructions.end(); -// for(auto it = it_true_begin; it != it_true_end; it++) -//// if(!dynamic_cast(*it)) -// (*it)->set_mask_pred(true_mask); -// /* false value */ -// ir::value *false_mask = mask->get_result(1); -// auto it_false_begin = instructions.end(); -// it_false_begin--; -// ir::value *false_value = false_value_->codegen(mod); -// implicit_broadcast(mod, pred, false_value); -// bool is_float, is_ptr, is_int, is_signed; -// implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed); -// it_false_begin++; -// auto it_false_end = instructions.end(); -// for(auto it = it_false_begin; it != it_false_end; it++) -//// if(!dynamic_cast(*it)) -// (*it)->set_mask_pred(false_mask); -// /* psi */ -// ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value); -// return result; } /* Assignment expression */ diff --git a/lib/lang/statement.cpp b/lib/lang/statement.cpp index ab0a55828..a768bf7b4 100644 --- a/lib/lang/statement.cpp +++ b/lib/lang/statement.cpp @@ -29,21 +29,35 @@ ir::value* compound_statement::codegen(ir::module* mod) const{ /* Expression statement */ ir::value* expression_statement::codegen(ir::module *mod) const{ ir::builder &builder = mod->get_builder(); - ir::value *expr = expr_->codegen(mod); - if(pred_ == nullptr) - return expr; - ir::value *pred = pred_->codegen(mod); - if(auto *x = dynamic_cast(expr)) - x->set_mask(pred); - else if(auto *x = dynamic_cast(expr)) - x->set_mask(pred); - else - expr = builder.create_select(pred, expr, 
ir::undef_value::get(expr->get_type())); + // get name if applicable + std::string name = ""; + ir::value *current = nullptr; if(assignment_expression *assignment = dynamic_cast(expr_)) - if(auto *named = dynamic_cast(assignment)){ - std::string name = named->lvalue()->id()->name(); - mod->set_value(name, expr); + if(const named_expression* named = dynamic_cast(assignment->lvalue())){ + name = named->id()->name(); + current = mod->get_value(name); } + // lower expression + ir::value *expr = expr_->codegen(mod); + // modify expression if predicated + if(pred_) { + ir::value *pred = pred_->codegen(mod); + if(!current) + current = ir::undef_value::get(expr->get_type()); + if(auto *x = dynamic_cast(expr)){ + x->erase_from_parent(); + expr = builder.create_masked_load(x->get_pointer_operand(), pred, current); + } + else if(auto *x = dynamic_cast(expr)){ + x->erase_from_parent(); + expr =builder.create_masked_store(x->get_pointer_operand(), x->get_value_operand(), pred); + } + else + expr = builder.create_select(pred, expr, current); + } + // update symbols table + if(!name.empty()) + mod->set_value(name, expr); return expr; }