From 5efdb7978e41d482f724fba9417c178a3451ad98 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 6 Aug 2019 16:21:20 -0700
Subject: [PATCH] more improvements and regressions

---
 examples/cpp/dot.cpp             |  6 ++--
 include/triton/ir/constant.h     |  3 +-
 include/triton/ir/instructions.h | 14 +++++++---
 include/triton/lang/expression.h |  3 +-
 include/triton/lang/parser.y     |  3 +-
 lib/codegen/optimize_dot.cpp     | 24 ++++++++--------
 lib/codegen/optimize_trans.cpp   | 18 +++++++++---
 lib/codegen/selection.cpp        | 16 +++++++++--
 lib/codegen/tune.cpp             | 48 +++++++++++++++++++++-----------
 lib/dnn/base.cpp                 |  4 ++-
 lib/dnn/dot.cpp                  | 14 ++++++----
 lib/ir/instructions.cpp          | 36 ++++++++++++++----------
 lib/lang/expression.cpp          | 12 +++++++-
 lib/runtime/jit.cpp              |  6 ++--
 14 files changed, 138 insertions(+), 69 deletions(-)

diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp
index d4f5adb6e..87bb739e2 100644
--- a/examples/cpp/dot.cpp
+++ b/examples/cpp/dot.cpp
@@ -26,7 +26,7 @@ struct perf_t {
 perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
   typedef float NumericT;
-  std::string ty = "float";
+  std::string ty = "half";
   size_t dt_nbytes = sizeof(NumericT);
   triton::driver::context* context = stream->context();
   std::vector hc(M*N);
@@ -48,7 +48,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   stream->synchronize();
   triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
   // benchmark triton
-  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
+  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
   // benchmark cublas
 //  NumericT alpha = 1;
 //  NumericT beta = 0;
@@ -111,7 +111,7 @@ int main() {
   std::vector configs = {
 //    {false, false, 8192, 512, 512},
 //    {false, true, 8192, 8192, 8192}
-    {false, true, 128, 128, 128},
+    {true, true, 128, 128, 128},
 //    {false, true, 32768, 256, 512}
 //    {true, false, 8192, 512, 512},
 //    {true, true, 8192, 512, 512}
diff --git a/include/triton/ir/constant.h b/include/triton/ir/constant.h
index ca44c6227..ce618d998 100644
--- a/include/triton/ir/constant.h
+++ b/include/triton/ir/constant.h
@@ -38,6 +38,7 @@ protected:
 
 public:
   virtual uint64_t get_value() const { return value_; }
+  virtual std::string repr() const { return std::to_string(get_value()); }
   static constant_int *get(type *ty, uint64_t value);
 
 protected:
@@ -57,7 +58,7 @@ public:
   const std::vector& get_space() { return space_; }
   void set_space(const std::vector &space) { space_ = space; }
   uint64_t get_value() const { assert(has_value_); return value_; }
-
+  std::string repr() const { return has_value_? std::to_string(value_) : "?" ;}
 
 private:
   std::vector space_;
   bool has_value_;
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index c0b176ebf..8bb46eb2a 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -584,12 +584,18 @@ private:
 
 class trans_inst: public builtin_inst {
 public:
-  ir::type* get_res_ty(ir::type* in);
-  std::vector<ir::constant_int*> get_default_perm(ir::type* ty);
+  ir::type* get_res_ty(ir::type* in, std::vector<ir::constant_int*> perm);
+  std::vector<ir::constant_int*> init_perm(ir::type* ty, const std::vector<ir::constant_int*>& perm);
 
 private:
   trans_inst(value *arg, const std::vector<ir::constant_int*>& perm, const std::string& name, instruction* next);
-  std::string repr_impl() const { return "trans"; }
+  std::string repr_impl() const {
+    std::string res = "trans<";
+    for(ir::constant_int *x: perm_)
+      res += x->repr() + ",";
+    res[res.size()-1] = '>';
+    return res;
+  }
 
 public:
   static instruction* create(value *arg, const std::vector<ir::constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
@@ -609,7 +615,7 @@ public:
 
 class reduce_inst: public builtin_inst {
 private:
-  static type* get_type(value *arg, unsigned axis);
+  static type* get_res_type(value *arg, unsigned axis);
 
 private:
   reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
diff --git a/include/triton/lang/expression.h b/include/triton/lang/expression.h
index f0dac3bc9..6823e8988 100644
--- a/include/triton/lang/expression.h
+++ b/include/triton/lang/expression.h
@@ -180,11 +180,12 @@ private:
 
 class trans_expression: public builtin_expression{
 public:
-  trans_expression(node *arg): arg_(arg) {}
+  trans_expression(node *arg, node *perm): arg_(arg), perm_((list<expression*>*)perm) {}
   ir::value* codegen(ir::module *mod) const;
 
 private:
   node* arg_;
+  const list<expression*>* perm_;
 };
 
 class sqrt_expression: public builtin_expression{
diff --git a/include/triton/lang/parser.y b/include/triton/lang/parser.y
index 0df37673b..c44a619e8 100644
--- a/include/triton/lang/parser.y
+++ b/include/triton/lang/parser.y
@@ -125,7 +125,8 @@ builtin_expression
   | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
   | SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
   | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
-  | TRANS '(' expression ')' { $$ = new trans_expression($3); }
+  | TRANS '(' expression ',' constant_expression_list ')' { $$ = new trans_expression($3, $5); }
+  | TRANS '(' expression ')' { $$ = new trans_expression($3, nullptr); }
   | REDUCE_SUM '(' expression ',' constant ')' { $$ = new reduce_expression($3, $5);}
   | MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
   | MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
diff --git a/lib/codegen/optimize_dot.cpp b/lib/codegen/optimize_dot.cpp
index e3ebfbcdb..904b1c6c3 100644
--- a/lib/codegen/optimize_dot.cpp
+++ b/lib/codegen/optimize_dot.cpp
@@ -8,7 +8,17 @@ namespace triton {
 namespace codegen{
 
 inline bool is_trans(ir::value *v){
-  return dynamic_cast<ir::trans_inst*>(v) != nullptr;
+  auto *x = dynamic_cast<ir::trans_inst*>(v);
+  if(!x)
+    return false;
+  std::vector<ir::constant_int*> perm = x->get_perm();
+  std::vector<ir::constant_int*> ref;
+  ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
+  for(size_t i = 0; i < perm.size(); i++)
+    ref.push_back(ir::constant_int::get(int32_ty, i));
+  std::swap(ref[0], ref[1]);
+  // true is perm == ref
+  return std::equal(perm.begin(), perm.end(), ref.begin());
 }
 
 inline bool is_hmma(ir::value *v){
@@ -28,7 +38,6 @@ inline bool is_hmma(ir::value *v){
 void optimize_dot::run(ir::module &mod) {
   ir::builder &builder = mod.get_builder();
-  std::vector<ir::instruction*> to_delete;
   // iterate
   for(ir::function *fn: mod.get_function_list())
   for(ir::basic_block *block: fn->blocks())
@@ -47,15 +56,12 @@ void optimize_dot::run(ir::module &mod) {
         ir::value *BB = B;
         if(trans_a){
           AA = ((ir::trans_inst*)A)->get_operand(0);
-          to_delete.push_back((ir::instruction*)A);
         }
         if(trans_b){
           BB = ((ir::trans_inst*)B)->get_operand(0);
-          to_delete.push_back((ir::instruction*)B);
         }
         ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
         dot->replace_all_uses_with(dot_atbt);
-        to_delete.push_back(dot);
       }
       else{
         // dot(op(a), trans(b))
@@ -63,28 +69,24 @@ void optimize_dot::run(ir::module &mod) {
           ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
           ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
           dot->replace_all_uses_with(NT);
-          to_delete.push_back((ir::instruction*)B);
-          to_delete.push_back(dot);
         }
         // dot(op(a), b)
         if(!trans_b){
+          // create permutations
          size_t size = B->get_type()->get_tile_shapes().size();
          std::vector<ir::constant_int*> perm(size);
          ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
          for(size_t i = 0; i < size; i++)
            perm[i] = ir::constant_int::get(int32_ty, i);
          std::swap(perm[0], perm[1]);
+          // replace NN -> NT (trans)
          ir::value* BB = builder.create_trans(B, perm);
          ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
          dot->replace_all_uses_with(NT);
-          to_delete.push_back(dot);
        }
      }
    }
 }
-
-  for(ir::instruction* i: to_delete)
-    i->erase_from_parent();
 }
 
 }
diff --git a/lib/codegen/optimize_trans.cpp b/lib/codegen/optimize_trans.cpp
index 16c4605d7..b1e0dc4b9 100644
--- a/lib/codegen/optimize_trans.cpp
+++ b/lib/codegen/optimize_trans.cpp
@@ -42,22 +42,32 @@ void optimize_trans::run(ir::module &mod) {
   for(ir::function *fn: mod.get_function_list())
   for(ir::basic_block *block: fn->blocks())
   for(ir::instruction* i: block->get_inst_list()){
-    // filter transposition
+    // transposition
     if(auto trans = dynamic_cast<ir::trans_inst*>(i)) {
       auto users = trans->get_users();
       auto ops = trans->ops();
       if(users.size() > 1 || ops.size() > 1)
         continue;
 
       ir::value* op = *ops.begin();
-      // chains of transpositions
-      // TODO
-
+      // todo: chains of transpositions
       // trans(phi) -> phi(trans(), trans()...)
       if(dynamic_cast<ir::phi_node*>(op)){
         ir::value* new_phi = replace_phi(op, builder, trans->get_perm());
         trans->replace_all_uses_with(new_phi);
       }
     }
+    // reductions
+    if(auto x = dynamic_cast<ir::reduce_inst*>(i)) {
+      ir::constant_int *one = ir::constant_int::get(ir::type::get_int32_ty(i->get_type()->get_context()), 1);
+      ir::value *arg = x->get_operand(0);
+      auto shapes = arg->get_type()->get_tile_shapes();
+      if(shapes[x->get_axis()] == one){
+        builder.set_insert_point(x);
+        ir::value* new_red = builder.create_reshape(arg, x->get_type()->get_tile_shapes());
+        x->replace_all_uses_with(new_red);
+      }
+    }
+
   }
 }
diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp
index 3b973de71..8ec454842 100644
--- a/lib/codegen/selection.cpp
+++ b/lib/codegen/selection.cpp
@@ -996,8 +996,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
     distributed_tile *TC = (distributed_tile*)tmap_.at(C);
     Type *c_ty = llvm_type(C->get_type()->get_scalar_ty(), ctx);
     Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
+    auto A_shapes = A->get_type()->get_tile_shapes();
     size_t red_axis = dot->is_a_trans() ? 0 : 1;
-    unsigned NK = A->get_type()->get_tile_shapes()[red_axis]->get_value();
+    unsigned NK = A_shapes[red_axis]->get_value();
     if(NK != 1)
     {
       shared_tile *TA = (shared_tile*)tmap_.at(A);
@@ -1008,18 +1009,27 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
       result->for_each([&](indices_t idx){
         Value *res = TC->get_value(idx);
         for(unsigned K = 0; K < NK; ++K){
-          indices_t a_idx = {idx[0], builder.getInt32(K), idx[2]};
-          indices_t b_idx = {builder.getInt32(K), idx[1], idx[2]};
+          // input indices
+          indices_t a_idx = {idx[0], builder.getInt32(K)};
+          indices_t b_idx = {builder.getInt32(K), idx[1]};
           if(AT)
             std::swap(a_idx[0], a_idx[1]);
           if(BT)
             std::swap(b_idx[0], b_idx[1]);
+          // add batching dimension
+          for(size_t i = 2; i < idx.size(); i++){
+            a_idx.insert(a_idx.end(), idx[i]);
+            b_idx.insert(b_idx.end(), idx[i]);
+          }
+          // load value
           Value *a = TA->get_value(a_idx);
           Value *b = TB->get_value(b_idx);
           if(a->getType() != c_ty)
             a = builder.CreateFPCast(a, c_ty);
           if(b->getType() != c_ty)
             b = builder.CreateFPCast(b, c_ty);
+//          a = ConstantFP::get(builder.getFloatTy(), 1);
+//          b = ConstantFP::get(builder.getFloatTy(), 1);
           res = builder.CreateCall(f_mul_add, {a, b, res});
         }
         result->set_value(idx, res);
diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp
index 09017e978..bc4c7118d 100644
--- a/lib/codegen/tune.cpp
+++ b/lib/codegen/tune.cpp
@@ -67,8 +67,6 @@ void tune::init_c_graph(ir::instruction *v) {
         continue;
       add_constraint({reduce, current++}, {arg, i});
     }
-//    add_constraint({reduce, 0}, {arg, 0});
-//    add_constraint({reduce, 1}, {arg, 1});
     return;
   }
   else
@@ -115,7 +113,7 @@
     }
   }
   // Matrix multiplication
-  else if(dynamic_cast<ir::dot_inst*>(v)){
+  else if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
     ir::value *A = v->get_operand(0);
     ir::value *B = v->get_operand(1);
     ir::value *D = v->get_operand(2);
@@ -124,8 +122,8 @@
     for(unsigned i = 2; i < shapes.size(); i++){
       if(shapes[i] == one)
         static_params_.insert({{v, i}, 1});
-      add_constraint({v, i}, {A, i});
-      add_constraint({v, i}, {B, i});
+//      add_constraint({v, i}, {A, i});
+//      add_constraint({v, i}, {B, i});
     }
   }
   // Element-wise
@@ -268,35 +266,53 @@ void tune::run(ir::module &mod) {
   for(ir::function *fn: mod.get_function_list())
   for(ir::basic_block *block: fn->blocks())
   for(ir::instruction *i : block->get_inst_list()){
+
     if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
       continue;
-    if(auto *ld = dynamic_cast<ir::load_inst*>(i))
+
+    if(auto *x = dynamic_cast<ir::load_inst*>(i))
     if(i->get_type()->is_tile_ty()){
-      ir::type *ptr_ty = ld->get_pointer_operand()->get_type()->get_scalar_ty();
+      ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
       size_t addr_space = ptr_ty->get_pointer_address_space();
       if(addr_space < 4){
        ir::type *ty = mod.get_builder().get_int32_ty();
-       std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
+       std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
        *params_.at(i).at("nts.d0") = *tmp;
       }
     }
     if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
-      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
+      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
+      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
       *params_.at(i).at("nts.d0") = *tmp1;
       *params_.at(i).at("nts.d1") = *tmp2;
     }
   }
+
+  // initialize grids
+  for(ir::function *fn: mod.get_function_list()){
+    std::map references;
+    create_grids(grids_, references, fn);
+  }
+
+  for(ir::instruction *i: grids_){
+    auto shapes = i->get_type()->get_tile_shapes();
+    for(size_t k = 0; k < shapes.size(); k++)
+      if(shapes[k]->get_value() == 1) {
+        if(fragments_.at({i, k}) == STRIDED_SCAN){
+          params_.at(i).at("nts.d" + std::to_string(k))->set_value(1);
+          params_.at(i).at("mts.d" + std::to_string(k))->set_value(1);
+        }
+        if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){
+          params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1);
+          params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1);
+        }
+      }
+  }
 }
 
 void tune::init(ir::module &mod) {
-  for(ir::function *fn: mod.get_function_list()){
-    // initialize grids
-    std::map references;
-    create_grids(grids_, references, fn);
-  }
-  // number of threads
   num_threads_ = get_req_num_threads(grids_.front());
 }
diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp
index ebbe699c1..1ad741240 100644
--- a/lib/dnn/base.cpp
+++ b/lib/dnn/base.cpp
@@ -64,7 +64,9 @@ std::pair base::get_profile_impl(driver::stream *stream, std::v
   else{
 //    params_t params = heuristics();
 //    params_t params = jit->get_valid(name_.c_str(), src.c_str());
-    params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 1};
+//    params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
+//    params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
+    params_t params = {4, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 1, 32, 16, 4, 4, 4, 4, 4, 1}; // TT
     jit->add_module(name_.c_str(), src.c_str(), params);
   }
   triton::driver::kernel* kernel = jit->get_function(name_.c_str());
diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp
index a43ea1ca4..83798921a 100644
--- a/lib/dnn/dot.cpp
+++ b/lib/dnn/dot.cpp
@@ -74,22 +74,24 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
 void dot::triton_c_src(std::ostream &os) const {
   std::string AS0 = "TM", AS1 = "TK";
   std::string BS0 = "TK", BS1 = "TN";
-  std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4";
-  std::string XBS0 = "TK/4", XBS1 = "TN", XBS2 = "4";
+  std::string XAS0 = "TM", XAS1 = "TK/1", XAS2 = "1";
+  std::string XBS0 = "TK/1", XBS1 = "1", XBS2 = "TN";
   std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
   std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
   std::string lda0 = "*lda", lda1 = "";
   std::string ldb0 = "", ldb1 = "*ldb";
-  std::string usea = AT_ ? "trans(xa)" : "xa";
-  std::string useb = BT_ ? "trans(xb)" : "xb";
+  std::string usea = AT_ ? "trans(xa, 0, 2, 1)" : "xa";
+  std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)";
   if(AT_){
     std::swap(AS0, AS1);
     std::swap(XAS0, XAS1);
+    std::swap(XAS1, XAS2);
     std::swap(bca0, bca1);
     std::swap(lda0, lda1);
   }
   if(BT_){
     std::swap(BS0, BS1);
+    std::swap(XBS1, XBS2);
     std::swap(XBS0, XBS1);
     std::swap(bcb0, bcb1);
     std::swap(ldb0, ldb1);
   }
@@ -98,7 +100,7 @@
   std::string BS = BS0 + ", " + BS1;
   std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
   std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
-  std::string XCS = "TM, TN, 4";
+  std::string XCS = "TM, TN, 1";
   std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
   std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
   std::string res =
@@ -146,7 +148,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
 }
 )";
 
-  std::cout << res << std::endl;
+//  std::cout << res << std::endl;
   os << res;
 }
 
diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp
index 58a81cd3b..7ae5b73ec 100644
--- a/lib/ir/instructions.cpp
+++ b/lib/ir/instructions.cpp
@@ -482,7 +482,8 @@ std::string retile_inst::shape_suffix(ir::type* ty){
   std::string res = "[";
   const auto& shapes = ty->get_tile_shapes();
   for(unsigned i = 0; i < shapes.size(); i++){
-    res += std::to_string(ty->get_tile_shapes()[i]->get_value());
+    ir::constant_int *shape_i = ty->get_tile_shapes()[i];
+    res += shape_i->repr();
     if(i < shapes.size() - 1)
       res += ", ";
   }
@@ -566,26 +567,33 @@ instruction *dot_inst::create_tt(value *A, value *B, value *C,
 //===----------------------------------------------------------------------===//
 //                               trans instructions
 //===----------------------------------------------------------------------===//
-ir::type* trans_inst::get_res_ty(ir::type* ty) {
-  auto shapes = ty->get_tile_shapes();
-  std::rotate(shapes.begin(), shapes.begin() + 1, shapes.end());
-  return tile_type::get(ty->get_scalar_ty(), shapes);
+ir::type* trans_inst::get_res_ty(ir::type* ty, std::vector<ir::constant_int*> perm) {
+  // get argument shapes
+  ir::tile_type::tile_shapes_t arg_shapes = ty->get_tile_shapes();
+  // permutate argument shapes
+  perm = init_perm(ty, perm);
+  ir::tile_type::tile_shapes_t res_shapes = arg_shapes;
+  for(int i = 0; i < perm.size(); i++)
+    res_shapes[i] = arg_shapes[perm[i]->get_value()];
+  // construct type
+  return tile_type::get(ty->get_scalar_ty(), res_shapes);
 }
 
-std::vector<ir::constant_int*> trans_inst::get_default_perm(ir::type* ty) {
+std::vector<ir::constant_int*> trans_inst::init_perm(ir::type* ty, const std::vector<ir::constant_int*>& perm) {
+  if(!perm.empty())
+    return perm;
   auto size = ty->get_tile_shapes().size();
   ir::type* int32_ty = type::get_int32_ty(ty->get_context());
   std::vector<ir::constant_int*> result;
-  for(size_t i = 0; i < size; i++)
-    result.push_back(ir::constant_int::get(int32_ty, i + 1 % size));
+  result.push_back(ir::constant_int::get(int32_ty, size - 1));
+  for(int i = 0; i < size - 1; i++)
+    result.push_back(ir::constant_int::get(int32_ty, i));
   return result;
 }
 
 trans_inst::trans_inst(value *arg, const std::vector<ir::constant_int*>& perm, const std::string &name, instruction *next)
-  : builtin_inst(get_res_ty(arg->get_type()), 1, 1, name, next) {
-  perm_ = perm;
-  if(perm_.empty())
-    perm_ = get_default_perm(arg->get_type());
+  : builtin_inst(get_res_ty(arg->get_type(), perm), 1, 1, name, next) {
+  perm_ = init_perm(arg->get_type(), perm);
   auto size = arg->get_type()->get_tile_shapes().size();
   assert(perm_.size() == size);
   set_operand(0, arg);
@@ -615,7 +623,7 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
 //===----------------------------------------------------------------------===//
 //                               reduce instructions
 //===----------------------------------------------------------------------===//
-type* reduce_inst::get_type(value *arg, unsigned axis) {
+type* reduce_inst::get_res_type(value *arg, unsigned axis) {
   ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
   shapes.erase(shapes.begin() + axis);
   type *scalar_ty = arg->get_type()->get_scalar_ty();
@@ -626,7 +634,7 @@ type* reduce_inst::get_type(value *arg, unsigned axis) {
 }
 
 reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
-  : builtin_inst(get_type(arg, axis), 1, 1, name, next),
+  : builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
     axis_(axis){
   set_operand(0, arg);
 }
diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp
index c54179943..acbfaf6f6 100644
--- a/lib/lang/expression.cpp
+++ b/lib/lang/expression.cpp
@@ -203,7 +203,17 @@ ir::value* select_expression::codegen(ir::module *mod) const {
 
 // trans
 ir::value* trans_expression::codegen(ir::module *mod) const {
-  return mod->get_builder().create_trans(arg_->codegen(mod));
+  // shapes
+  std::vector<ir::constant_int*> perm;
+  if(perm_) {
+    for(expression *expr: perm_->values()){
+      ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
+      if(shape == nullptr)
+        throw std::runtime_error("tile shapes must be constant expressions");
+      perm.push_back(shape);
+    }
+  }
+  return mod->get_builder().create_trans(arg_->codegen(mod), perm);
 }
 
 // sqrt
diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp
index 86102a460..1f6a60ccd 100644
--- a/lib/runtime/jit.cpp
+++ b/lib/runtime/jit.cpp
@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector const & ranges,
   size_t D = ranges.size();
   std::vector values(D, 0);
   // thread pools
-//  ThreadPool pool(nthreads);
+  ThreadPool pool(nthreads);
   // Start with innermost loop
   size_t i = D - 1;
   while(true){
     // Execute function
-//    pool.enqueue(f,values);
-    f(values);
+    pool.enqueue(f,values);
+//    f(values);
     while(values[i]++ == ranges[i] - 1){
       if(i == 0)
         return;
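
Note (not part of the commit above, which is truncated mid-hunk in lib/runtime/jit.cpp): the short C++ sketch below only illustrates the permutation rule this patch gives trans(). Per trans_inst::init_perm, an empty permutation defaults to moving the last axis to the front (an ordinary transpose for 2-D tiles), and per trans_inst::get_res_ty the result shape at position i is the argument shape at position perm[i]. Plain ints stand in for ir::constant_int*, and the helper names default_perm and permute_shape are illustrative only, not Triton API.

// Illustrative sketch only -- mirrors the behavior of trans_inst::init_perm and
// trans_inst::get_res_ty from the patch, with int in place of ir::constant_int*.
#include <cassert>
#include <cstddef>
#include <vector>

// Default permutation: last axis first, e.g. {2, 0, 1} for a 3-D tile.
std::vector<int> default_perm(std::size_t rank) {
  std::vector<int> perm;
  perm.push_back(static_cast<int>(rank) - 1);
  for (int i = 0; i + 1 < static_cast<int>(rank); i++)
    perm.push_back(i);
  return perm;
}

// Result shape: res[i] = arg[perm[i]], as in trans_inst::get_res_ty.
std::vector<int> permute_shape(const std::vector<int>& arg, std::vector<int> perm) {
  if (perm.empty())
    perm = default_perm(arg.size());
  std::vector<int> res(arg.size());
  for (std::size_t i = 0; i < perm.size(); i++)
    res[i] = arg[perm[i]];
  return res;
}

int main() {
  // trans(xa, 0, 2, 1), as emitted in lib/dnn/dot.cpp: swap the last two axes.
  assert((permute_shape({16, 8, 4}, {0, 2, 1}) == std::vector<int>{16, 4, 8}));
  // trans(x) with no explicit permutation: {TM, TK} -> {TK, TM} for a 2-D tile.
  assert((permute_shape({16, 8}, {}) == std::vector<int>{8, 16}));
  return 0;
}

This is the same rule the generated Triton-C in lib/dnn/dot.cpp relies on: passing an explicit index list to trans() lets the AT/BT variants reorder only the non-batch axes instead of the fixed 2-D transpose the builtin performed before this change.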