diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h
index 3f7e5686c..b480da5f0 100644
--- a/include/triton/codegen/selection.h
+++ b/include/triton/codegen/selection.h
@@ -32,6 +32,7 @@ typedef std::vector<llvm::Value*> indices_t;
 struct distributed_axis {
   size_t contiguous;
   std::vector<llvm::Value*> values;
+  llvm::Value* thread_id;
 };
 
 class tile {
diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h
index 1921814c9..079a79e40 100644
--- a/include/triton/ir/builder.h
+++ b/include/triton/ir/builder.h
@@ -134,7 +134,7 @@ public:
   value *create_dot(value *A, value *B, value *C, const std::string &name = "");
   value *create_trans(value *A, const std::string &name = "");
   value *create_sqrt(value *A, const std::string &name = "");
-  value *create_reduce(value *A, const std::string &name = "");
+  value *create_reduce(value *A, unsigned axis, const std::string &name = "");
   value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
   // Intrinsics
   value *create_copy_to_shared(value *arg, const std::string &name = "");
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index 37692d617..3cc86da26 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -605,11 +605,18 @@ public:
 
 class reduce_inst: public builtin_inst {
 private:
-  reduce_inst(value* arg, const std::string& name, instruction* next);
+  static type* get_type(value *arg, unsigned axis);
+
+private:
+  reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
   std::string repr_impl() const { return "reduce"; }
 
 public:
-  static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
+  unsigned get_axis() const { return axis_; }
+
+private:
+  unsigned axis_;
 };
 
 class select_inst: public builtin_inst {
diff --git a/include/triton/lang/expression.h b/include/triton/lang/expression.h
index 13894d18a..f0dac3bc9 100644
--- a/include/triton/lang/expression.h
+++ b/include/triton/lang/expression.h
@@ -134,6 +134,16 @@ private:
   const expression *C_;
 };
 
+class reshape_expression: public builtin_expression{
+public:
+  reshape_expression(node *arg, node *shapes): arg_(arg), shapes_((list<expression*>*)shapes) { }
+  ir::value* codegen(ir::module *) const;
+
+private:
+  const node *arg_;
+  const list<expression*>* shapes_;
+};
+
 class max_expression: public builtin_expression{
 public:
   max_expression(node* x, node* y)
@@ -188,11 +198,12 @@ private:
 
 class reduce_expression: public builtin_expression{
 public:
-  reduce_expression(node *arg): arg_(arg) {}
+  reduce_expression(node *arg, node *axis): arg_(arg), axis_((constant*)axis) {}
   ir::value* codegen(ir::module *mod) const;
 
 private:
   node* arg_;
+  constant* axis_;
 };
 
 class indexing_expression: public postfix_expression{
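The get_axis/get_type additions above encode the shape rule for the new reductions: summing along `axis` erases that dimension from the tile shape, and a fully-reduced tile degenerates to a scalar (which is why batchnorm's `__sum(x, 0)` on a 1-D tile still yields a float). A standalone sketch of the rule, for illustration only -- reduced_shape is a hypothetical helper, not part of the patch:

// Sketch of the reduce shape rule implemented by reduce_inst::get_type below.
#include <cassert>
#include <cstdio>
#include <vector>

std::vector<unsigned> reduced_shape(std::vector<unsigned> shape, unsigned axis) {
  assert(axis < shape.size());
  shape.erase(shape.begin() + axis);  // reducing along `axis` drops that dimension
  return shape;                       // empty result means the reduction yields a scalar
}

int main() {
  auto s0 = reduced_shape({128, 64}, 0); // {64}  -- __sum(x, 0) on a [128, 64] tile
  auto s1 = reduced_shape({128, 64}, 1); // {128}
  auto s2 = reduced_shape({128}, 0);     // {}    -- scalar, as in batchnorm
  std::printf("%zu %zu %zu\n", s0.size(), s1.size(), s2.size());
  return 0;
}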
diff --git a/include/triton/lang/parser.y b/include/triton/lang/parser.y
index cd2c8941b..1a4b26633 100644
--- a/include/triton/lang/parser.y
+++ b/include/triton/lang/parser.y
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
 %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
 %token IF ELSE FOR CONTINUE WHILE
 %token NEWAXIS ELLIPSIS AT
-%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST
+%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE
 
 %start translation_unit
 %%
@@ -126,13 +126,14 @@ builtin_expression
 	| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
 	| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
 	| TRANS '(' expression ')' { $$ = new trans_expression($3); }
-	| REDUCE_SUM '(' expression ')' { $$ = new reduce_expression($3);}
+	| REDUCE_SUM '(' expression ',' constant ')' { $$ = new reduce_expression($3, $5);}
 	| MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
 	| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
 	| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
 	| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
 	| ATOMIC_EXCH '(' expression ',' expression ')' { $$ = new atomic_exch_expression($3, $5); }
 	| ATOMIC_ADD '(' expression ',' expression ')' { $$ = new atomic_add_expression($3, $5); }
+	| RESHAPE '(' expression ',' primary_expression_list ')' { $$ = new reshape_expression($3, $5); }
 	;
 
 /* Primary */
diff --git a/include/triton/lang/scanner.l b/include/triton/lang/scanner.l
index fc791ae94..1aaf40a57 100644
--- a/include/triton/lang/scanner.l
+++ b/include/triton/lang/scanner.l
@@ -30,18 +30,18 @@ using triton::lang::return_void;
 "for"    { return return_impl(FOR, yytext); }
 "while"  { return return_impl(WHILE, yytext); }
 "void"   { return return_impl(VOID, yytext); }
-"uchar"  { return return_impl(UCHAR, yytext); }
-"ushort" { return return_impl(USHORT, yytext); }
-"uint"   { return return_impl(UINT, yytext); }
-"ulong"  { return return_impl(ULONG, yytext); }
-"bool"   { return return_impl(BOOL, yytext); }
-"char"   { return return_impl(CHAR, yytext); }
-"short"  { return return_impl(SHORT, yytext); }
-"int"    { return return_impl(INT, yytext); }
-"long"   { return return_impl(LONG, yytext); }
-"half"   { return return_impl(HALF, yytext); }
-"float"  { return return_impl(FLOAT, yytext); }
-"double" { return return_impl(DOUBLE, yytext); }
+"uchar"  { return return_impl(UINT8, yytext); }
+"ushort" { return return_impl(UINT16, yytext); }
+"uint"   { return return_impl(UINT32, yytext); }
+"ulong"  { return return_impl(UINT64, yytext); }
+"bool"   { return return_impl(INT1, yytext); }
+"char"   { return return_impl(INT8, yytext); }
+"short"  { return return_impl(INT16, yytext); }
+"int"    { return return_impl(INT32, yytext); }
+"long"   { return return_impl(INT64, yytext); }
+"half"   { return return_impl(FP16, yytext); }
+"float"  { return return_impl(FP32, yytext); }
+"double" { return return_impl(FP64, yytext); }
 "..."    { return return_impl(ELLIPSIS, yytext); }
 "get_range_id"    { return return_impl(GET_RANGE_ID, yytext); }
 "get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }
@@ -49,6 +49,7 @@ using triton::lang::return_void;
 "__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); }
 "__atomic_add"  { return return_impl(ATOMIC_ADD, yytext); }
 "__sum"         { return return_impl(REDUCE_SUM, yytext); }
+"__reshape"     { return return_impl(RESHAPE, yytext); }
 "sqrt" { return return_impl(SQRT, yytext); }
 "dot"  { return return_impl(DOT, yytext); }
 "max"  { return return_impl(MAX, yytext); }
diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp
index 7e981672a..dc8980a28 100644
--- a/lib/codegen/selection.cpp
+++ b/lib/codegen/selection.cpp
@@ -80,9 +80,10 @@ indices_t distributed_tile::get_ordered_indices(unsigned id) {
 
 
 void distributed_tile::for_each(std::function<void(indices_t)> fn) {
-  for(unsigned i = 0; i < ordered_indices_.size(); i++)
+  for(unsigned i = 0; i < ordered_indices_.size(); i++){
     if(i % vector_size_ == 0)
       fn(ordered_indices_[i]);
+  }
 }
 
 /* Shared Tile */
@@ -498,15 +499,15 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
       Value *warp_size_k = builder.getInt32(warp_size[k]);
       Value *contiguous_k = builder.getInt32(contiguous[k]);
       Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
-      thread_id = builder.CreateMul(thread_id, contiguous_k);
+      Value *scaled_thread_id = builder.CreateMul(thread_id, contiguous_k);
       unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
       unsigned per_thread = contiguous[k] * shapes[k]->get_value() / per_block;
       std::vector<Value*> idx_list(per_thread);
       for(unsigned n = 0 ; n < per_thread; n++){
         unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
-        idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
+        idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
       }
-      axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list};
+      axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
     }
   }
   else {
@@ -671,7 +672,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
     shapes[0] += pad;
   Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx);
   // create shared tile
-  if(buffer_info_->is_shared(v)){
+  if(buffer_info_->is_shared(v) && !dynamic_cast<ir::reduce_inst*>(v)){
     // shared copy
     PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace());
     // phi-node (double-buffering)
@@ -825,88 +826,72 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
   }
   if(auto *x = dynamic_cast<ir::reduce_inst*>(ins)){
     std::map<indices_t, Value*> partial;
-    distributed_tile* op = (distributed_tile*)tmap_.at(ins->get_operand(0));
-    size_t axis = 0;
-    unsigned num_warps = params_->get_num_threads() / 32;
-    std::vector<unsigned> shapes = op->get_shapes();
-    shapes.erase(shapes.begin() + axis);
-    if(shapes.empty())
-      shapes.push_back(1);
+    ir::value *op = ins->get_operand(0);
+    distributed_tile* op_tile = (distributed_tile*)tmap_.at(op);
+    unsigned axis = x->get_axis();
     // reduce within thread
-    op->for_each([&](indices_t idx){
+    op_tile->for_each([&](indices_t idx) {
       indices_t pidx = idx;
       pidx.erase(pidx.begin() + axis);
-      if(pidx.empty())
-        pidx.push_back(builder.getInt32(0));
-      Value *current = op->get_value(idx);
+      Value *current = op_tile->get_value(idx);
+      // current partial result is not initialized -- create
       if(partial.find(pidx) == partial.end())
         partial[pidx] = current;
+      // current partial result is initialized -- accumulate
       else
         partial[pidx] = builder.CreateFAdd(partial[pidx], current);
     });
-    // reduce within warp
-    Value *shfl = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_shfl_sync_bfly_f32);
-    for (int i = 16; i > 0; i >>= 1)
-    for(auto& x: partial)
-    {
-      Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), x.second,
-                                             builder.getInt32(i),
-                                             builder.getInt32(0x1f)});
-      x.second = builder.CreateFAdd(x.second, rhs);
-    }
-
-    // reduce within block
-    Value *tid = tgt_->get_local_id(module, builder, 0);
-    BasicBlock *partial_reduce_do = BasicBlock::Create(ctx, "partial_reduce_do", fn);
-    BasicBlock *partial_reduce_done = BasicBlock::Create(ctx, "partial_reduce_done", fn);
-    Value *id_in_warp = builder.CreateURem(tid, builder.getInt32(32));
-    Value *warp_id = builder.CreateUDiv(tid, builder.getInt32(32));
-    builder.CreateCondBr(builder.CreateICmpEQ(id_in_warp, builder.getInt32(0)),
-                         partial_reduce_do, partial_reduce_done);
-    builder.SetInsertPoint(partial_reduce_do);
+    // reduce within blocks
     unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
-    Type *ptr_ty = PointerType::get(builder.getFloatTy(), addr_space);
-    Value *sh_mem_ptr = builder.CreateBitCast(sh_mem_ptr_, ptr_ty);
-    for(auto& x: partial){
-      Value *offset = shared_tile::shared_offset(builder, shapes, x.first);
-      offset = builder.CreateAdd(offset, builder.CreateMul(warp_id, builder.getInt32(shapes[0])));
-      Value *write_ptr = builder.CreateGEP(sh_mem_ptr, offset);
-      builder.CreateStore(x.second, write_ptr);
-    }
-    builder.CreateBr(partial_reduce_done);
-    builder.SetInsertPoint(partial_reduce_done);
+    Type *res_ty = builder.getFloatTy();
+    Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
+    unsigned depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
+    for(auto& x: partial) {
+      // current element being computed
+      Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
+      Value *&result = x.second;
+      indices_t write_idx = x.first;
+      write_idx.insert(write_idx.begin() + axis, lane);
+      // shared memory write pointer
+      Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
+      Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
+      // initialize shared memory
+      builder.CreateStore(result, write_ptr);
+      // build result
+      for(unsigned i = depth/2; i > 0; i >>= 1){
+        // current indices
+        indices_t current(write_idx.size(), builder.getInt32(0));
+        current[axis] = builder.getInt32(i);
+        // shared memory offset
+        Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), current);
+        Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
+        read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
+        // shared memory read pointer
+        Value *read_ptr = builder.CreateGEP(write_ptr, read_offset);
+        tgt_->add_barrier(module, builder);
+        Value *next = builder.CreateLoad(read_ptr);
+        // accumulate
+        result = builder.CreateFAdd(result, next);
+        // write back
+        builder.CreateStore(result, write_ptr);
+      }
-
-    // Final reduction with the first warp
-    tgt_->add_barrier(module, builder);
-    BasicBlock *final_reduce_do = BasicBlock::Create(ctx, "final_reduce_do", fn);
-    BasicBlock *final_reduce_done = BasicBlock::Create(ctx, "final_reduce_done", fn);
-    builder.CreateCondBr(builder.CreateICmpEQ(warp_id, builder.getInt32(0)),
-                         final_reduce_do, final_reduce_done);
-    builder.SetInsertPoint(final_reduce_do);
-    Value *read_ptr = builder.CreateGEP(sh_mem_ptr, tid);
-    BasicBlock *read_shmem_do = BasicBlock::Create(ctx, "read_shmem_do", fn);
-    BasicBlock *read_shmem_done = BasicBlock::Create(ctx, "read_shmem_done", fn);
-    builder.CreateCondBr(builder.CreateICmpULT(id_in_warp, builder.getInt32(num_warps)),
-                         read_shmem_do, read_shmem_done);
-    builder.SetInsertPoint(read_shmem_do);
-    Value *loaded = builder.CreateLoad(read_ptr);
-    builder.CreateBr(read_shmem_done);
-    builder.SetInsertPoint(read_shmem_done);
-    Value *result = builder.CreatePHI(loaded->getType(), 2);
-    ((PHINode*)result)->addIncoming(ConstantFP::get(loaded->getType(), (double)0), final_reduce_do);
-    ((PHINode*)result)->addIncoming(loaded, read_shmem_do);
-    for (int i = params_->get_num_threads() / 64; i > 0; i >>= 1){
-      Value *rhs = builder.CreateCall(shfl, {builder.getInt32(0xffffffff), result,
-                                             builder.getInt32(i), builder.getInt32(0x1f)});
-      result = builder.CreateFAdd(result, rhs);
+      // result is on the first lane of shared memory
+      indices_t final = write_idx;
+      final[axis] = builder.getInt32(0);
+      Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), final);
+      Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
+      tgt_->add_barrier(module, builder);
+      result = builder.CreateLoad(read_ptr);
+      if(tmap_.find(ins) == tmap_.end())
+        vmap_[ins] = result;
+      else{
+        distributed_tile *ti = (distributed_tile*)tmap_[ins];
+        ti->set_value(x.first, result);
+      }
     }
-    builder.CreateStore(result, read_ptr);
-    builder.CreateBr(final_reduce_done);
-    builder.SetInsertPoint(final_reduce_done);
-    tgt_->add_barrier(module, builder);
-    vmap_[ins] = builder.CreateLoad(sh_mem_ptr);
     return;
   }
   tile *ti = tmap_[ins];
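This replaces the old warp-shuffle + PHI scheme with a single log-step reduction in shared memory: every lane along the reduced axis stores its thread-private partial sum, then at step i the lanes with lane < i accumulate the value i slots away, with a barrier between steps, until the total sits in lane 0. A CPU model of that ladder, for illustration only -- it assumes depth (the mts.d<axis> parameter) is a power of two and only models the active lanes:

// CPU model of the log-step shared-memory reduction emitted above.
#include <cstdio>
#include <numeric>
#include <vector>

float block_reduce(std::vector<float> buf) {
  const unsigned depth = static_cast<unsigned>(buf.size()); // power of two assumed
  for (unsigned i = depth / 2; i > 0; i >>= 1) {
    for (unsigned lane = 0; lane < i; ++lane) // the "is_active" lanes
      buf[lane] += buf[lane + i];
    // <- a barrier separates the steps in the generated code
  }
  return buf[0]; // result is on the first lane
}

int main() {
  std::vector<float> partials(8);
  std::iota(partials.begin(), partials.end(), 1.0f); // 1 + 2 + ... + 8
  std::printf("%f\n", block_reduce(partials));       // prints 36.0
  return 0;
}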
diff --git a/lib/codegen/shmem_allocation.cpp b/lib/codegen/shmem_allocation.cpp
index b4a903c1a..641170215 100644
--- a/lib/codegen/shmem_allocation.cpp
+++ b/lib/codegen/shmem_allocation.cpp
@@ -43,15 +43,16 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
 
 unsigned shmem_allocation::get_num_bytes(ir::value *x) {
   unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
-  if(dynamic_cast<ir::reduce_inst*>(x)){
-    size_t shape = 1;
-    if(x->get_type()->is_tile_ty()){
-      auto shapes = x->get_type()->get_tile_shapes();
-      for(auto x: shapes)
-        shape *= x->get_value();
-    }
-    size_t n_warps = params_->get_num_threads() / 32;
-    return shape * num_bytes * n_warps;
+  if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
+    size_t axis = red->get_axis();
+    ir::value *op = red->get_operand(0);
+    auto shapes = op->get_type()->get_tile_shapes();
+    shapes.erase(shapes.begin() + axis);
+    size_t num_elements = 1;
+    for(auto x: shapes)
+      num_elements *= x->get_value();
+    size_t depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
+    return num_elements * num_bytes * depth;
   }
   unsigned pad = is_ld_padded(x);
   if(pad > 0){
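In other words, the scratch requirement changes from one result-tile copy per warp to one element per (non-reduced index, reduction lane) pair. A back-of-the-envelope model of the new rule, for illustration only -- reduce_scratch_bytes is a hypothetical helper, not part of the patch:

// Model of the shared-memory scratch size computed above for a reduction.
#include <cstdio>

size_t reduce_scratch_bytes(const unsigned *shapes, unsigned rank,
                            unsigned axis, unsigned depth, unsigned elt_bytes) {
  size_t num_elements = 1;
  for (unsigned i = 0; i < rank; ++i)
    if (i != axis)
      num_elements *= shapes[i]; // product of the non-reduced dimensions
  return num_elements * elt_bytes * depth; // one slot per reduction lane
}

int main() {
  unsigned shapes[2] = {128, 64};
  // reducing axis 0 of a [128, 64] float tile with 32 lanes: 64 * 4 * 32 bytes
  std::printf("%zu\n", reduce_scratch_bytes(shapes, 2, 0, 32, 4));
  return 0;
}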
diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp
index 1a7ec94e5..db3ed1c81 100644
--- a/lib/codegen/tune.cpp
+++ b/lib/codegen/tune.cpp
@@ -58,8 +58,18 @@ void tune::init_c_graph(ir::instruction *v) {
     shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
   else if(auto *downcast = dynamic_cast<ir::downcast_inst*>(v))
     return;
-  else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v))
+  else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
+    unsigned axis = reduce->get_axis();
+    ir::value *arg = reduce->get_operand(0);
+    auto in_shapes = arg->get_type()->get_tile_shapes();
+    unsigned current = 0;
+    for(unsigned i = 0; i < in_shapes.size(); i++){
+      if(i == axis)
+        continue;
+      add_constraint({reduce, current++}, {arg, i});
+    }
     return;
+  }
   else
     shapes = v->get_type()->get_tile_shapes();
   // Reshape
@@ -74,8 +84,10 @@ void tune::init_c_graph(ir::instruction *v) {
       static_params_.insert({{v, i}, 1});
     else if(!is_skewed && is_same)
      add_constraint({v, i}, {op, current++});
-    else
+    else{
       is_skewed = true;
+      add_constraint({v, i}, {v, i});
+    }
   }
 }
 // Splat
diff --git a/lib/dnn/batchnorm.cpp b/lib/dnn/batchnorm.cpp
index dcc9d6a4e..e5143755e 100644
--- a/lib/dnn/batchnorm.cpp
+++ b/lib/dnn/batchnorm.cpp
@@ -71,7 +71,7 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
 
 void batchnorm_forward::triton_c_src(std::ostream &os) const {
   os << R"(
-const tunable int TM = {32, 64, 128};
+const tunable int TM = {128};
 
 void batchnorm_forward(float *Y, float *M, float *V,
                        restrict read_only float *X,
@@ -94,7 +94,7 @@ void batchnorm_forward(float *Y, float *M, float *V,
     px = px + TM;
   }
   float *pm = M + c;
-  float m = __sum(mean) * rcpDHWN;
+  float m = __sum(mean, 0) * rcpDHWN;
   *pm = m;
 
   float var[TM] = 0;
@@ -105,7 +105,7 @@ void batchnorm_forward(float *Y, float *M, float *V,
     var = var + x*x;
     px = px + TM;
   }
-  float v = __sum(var) * rcpDHWN;
+  float v = __sum(var, 0) * rcpDHWN;
   float *pv = V + c;
   *pv = v;
   float rstdg = 1 / sqrt(v + eps) * g;
@@ -167,7 +167,7 @@ void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *ke
 
 void batchnorm_backward::triton_c_src(std::ostream &os) const {
   os << R"(
-const tunable int TM = {32, 64, 128};
+const tunable int TM = {128};
 
 void batchnorm_backward(float *DX, float *DG, float *DB,
                         restrict read_only float *DY,
@@ -199,8 +199,8 @@ void batchnorm_backward(float *DX, float *DG, float *DB,
     px = px + TM;
     pdy = pdy + TM;
   }
-  float sdg = __sum(dg);
-  float sdb = __sum(db);
+  float sdg = __sum(dg, 0);
+  float sdb = __sum(db, 0);
   float *pdg = DG + c;
   float *pdb = DB + c;
   *pdg = sdg;
diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp
index 77c099827..a76c5e593 100644
--- a/lib/ir/builder.cpp
+++ b/lib/ir/builder.cpp
@@ -322,8 +322,8 @@ value *builder::create_sqrt(value *A, const std::string &name) {
   return insert(sqrt_inst::create(A, name));
 }
 
-value *builder::create_reduce(value *A, const std::string &name) {
-  return insert(reduce_inst::create(A, name));
+value *builder::create_reduce(value *A, unsigned axis, const std::string &name) {
+  return insert(reduce_inst::create(A, axis, name));
 }
 
 value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp
index a29c11914..e6e85ff85 100644
--- a/lib/ir/instructions.cpp
+++ b/lib/ir/instructions.cpp
@@ -597,13 +597,24 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
 //===----------------------------------------------------------------------===//
 //                               reduce instructions
 //===----------------------------------------------------------------------===//
-reduce_inst::reduce_inst(value *arg, const std::string &name, instruction *next)
-  : builtin_inst(arg->get_type()->get_scalar_ty(), 1, 1, name, next) {
+type* reduce_inst::get_type(value *arg, unsigned axis) {
+  ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
+  shapes.erase(shapes.begin() + axis);
+  type *scalar_ty = arg->get_type()->get_scalar_ty();
+  if(shapes.size() == 0)
+    return scalar_ty;
+  else
+    return tile_type::get(scalar_ty, shapes);
+}
+
+reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
+  : builtin_inst(get_type(arg, axis), 1, 1, name, next),
+    axis_(axis){
   set_operand(0, arg);
 }
 
-instruction* reduce_inst::create(value *arg, const std::string &name, instruction *next) {
-  return new reduce_inst(arg, name, next);
+instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &name, instruction *next) {
+  return new reduce_inst(arg, axis, name, next);
 }
diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp
index 15e66607a..470f0b3cd 100644
--- a/lib/lang/expression.cpp
+++ b/lib/lang/expression.cpp
@@ -161,6 +161,21 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
   return mod->get_builder().create_dot(A, B, C);
 }
 
+// reshape
+ir::value* reshape_expression::codegen(ir::module *mod) const {
+  // arg
+  ir::value *arg = arg_->codegen(mod);
+  // shapes
+  ir::type::tile_shapes_t shapes;
+  for(expression *expr: shapes_->values()){
+    ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
+    assert(shape);
+    shapes.push_back(shape);
+  }
+  // return
+  return mod->get_builder().create_reshape(arg, shapes);
+}
+
 // min
 ir::value* min_expression::codegen(ir::module *mod) const {
   ir::value* cmp = binary_expression(LT, (node*)x_, (node*)y_).codegen(mod);
@@ -198,7 +213,7 @@ ir::value* sqrt_expression::codegen(ir::module *mod) const {
 
 // reduce
 ir::value* reduce_expression::codegen(ir::module *mod) const {
-  return mod->get_builder().create_reduce(arg_->codegen(mod));
+  return mod->get_builder().create_reduce(arg_->codegen(mod), axis_->value());
 }
 
 /* Postfix expression */
diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp
index 1f6a60ccd..86102a460 100644
--- a/lib/runtime/jit.cpp
+++ b/lib/runtime/jit.cpp
@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
   size_t D = ranges.size();
   std::vector<size_t> values(D, 0);
   // thread pools
-  ThreadPool pool(nthreads);
+//  ThreadPool pool(nthreads);
   // Start with innermost loop
   size_t i = D - 1;
   while(true){
     // Execute function
-    pool.enqueue(f,values);
-//    f(values);
+//    pool.enqueue(f,values);
+    f(values);
     while(values[i]++ == ranges[i] - 1){
       if(i == 0)
         return;
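The excerpt cuts this hunk off here. For reference, a standalone model of the now-serialized driver, for illustration only -- it assumes the truncated loop body continues with the usual odometer carry (values[i--] = 0, then i reset to D - 1), consistent with the surrounding code:

// Model of parallel_loop_nest with the thread pool disabled: the inner while
// implements an odometer carry over `ranges`, calling f on every coordinate
// tuple in row-major order.
#include <cstdio>
#include <functional>
#include <vector>

void serial_loop_nest(const std::vector<size_t> &ranges,
                      const std::function<void(const std::vector<size_t>&)> &f) {
  size_t D = ranges.size();
  std::vector<size_t> values(D, 0);
  size_t i = D - 1;
  while (true) {
    f(values);
    // carry: increment the innermost counter, roll over into outer ones
    while (values[i]++ == ranges[i] - 1) {
      if (i == 0)
        return;
      values[i--] = 0;
    }
    i = D - 1;
  }
}

int main() {
  serial_loop_nest({2, 3}, [](const std::vector<size_t> &v) {
    std::printf("(%zu, %zu)\n", v[0], v[1]); // six tuples, (0,0) .. (1,2)
  });
  return 0;
}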