[codegen] some cleaning for batched matmul
@@ -40,7 +40,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   hb[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
 for(size_t i = 0; i < hc.size(); i++)
   hc[i] = static_cast<NumericT>((double)0);
-triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes);
+triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
 triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes);
 triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes);
 stream->write(da, true, 0, ha);
@@ -49,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
 stream->synchronize();
 triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
 // benchmark triton
-double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
+double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
 // benchmark cublas
 // NumericT alpha = 1;
 // NumericT beta = 0;
@@ -77,7 +77,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
 std::vector<NumericT> rc(hc.size());
 dot.cpu_ref(rc, ha, hb);
 for(size_t i = 0; i < M*N; i++)
-  if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
+  if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
     std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
     exit(EXIT_FAILURE);
   }
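The strengthened check above guards against Inf as well as NaN before applying a 1e-2 relative-error tolerance. A self-contained sketch of the same validation pattern (the helper name is hypothetical; the tolerance and the comparison come straight from the diff):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <cstdlib>
    #include <vector>

    // Compare a device result hc against a CPU reference rc, element-wise.
    // A mismatch is any Inf/NaN, or a relative error above 1e-2.
    template<class T>
    void check_result(const std::vector<T>& hc, const std::vector<T>& rc) {
      for (size_t i = 0; i < hc.size(); i++)
        if (std::isinf(hc[i]) || std::isnan(hc[i]) ||
            std::abs(hc[i] - rc[i]) / std::max(hc[i], rc[i]) > 1e-2) {
          std::printf("%zu %f %f\n", i, (double)hc[i], (double)rc[i]);
          std::exit(EXIT_FAILURE);
        }
    }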
@@ -111,8 +111,8 @@ int main() {
 // shapes to benchmark
 std::vector<config_t> configs = {
   // {false, false, 8192, 512, 512},
   // {false, true, 8192, 8192, 8192}
-  {false, true, 128, 128, 128},
-  {false, true, 128, 128, 128}
+  // {false, true, 128, 128, 128},
+  // {false, false, 128, 128, 128},
   // {true, false, 128, 128, 128},
   // {true, true, 128, 128, 128}

@@ -153,8 +153,8 @@ private:
 alignment_info *axis_info_;
 std::map<unsigned, distributed_axis> axes_;
 llvm::Value *sh_mem_ptr_;
-llvm::Value *offset_a_i_, *offset_a_k_, *offset_a_z_;
-llvm::Value *offset_b_j_, *offset_b_k_, *offset_b_z_;
+llvm::Value *offset_a_i_, *offset_a_k_;
+llvm::Value *offset_b_j_, *offset_b_k_;
 unsigned num_packs_0_, num_packs_1_;
 unsigned pack_size_0_, pack_size_1_;
 };

@@ -71,6 +71,7 @@ public:

 void target_independent(ir::module &module) {
   optimize_dot.run(module);
   optimize_dce.run(module);
   optimize_trans.run(module);
+  optimize_dce.run(module);
 }

@@ -86,7 +87,7 @@ public:
 }
 vectorize.run(module);
 optimize_dce.run(module);
-ir::print(module, std::cout);
+// ir::print(module, std::cout);
 }

 codegen::tune tune;

@@ -38,8 +38,8 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
 while(total_time*1e-9 < 1e-3){
   float norm = 1;
   // normalize clock if possible to reduce noise in auto-tuning
-  // if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
-  //   norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
+  if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
+    norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
   tmr.start();
   op();
   stream->synchronize();
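The re-enabled branch scales each sample by the ratio of current to maximum SM clock: a kernel timed while the GPU is down-clocked at f_cur runs roughly f_max/f_cur slower, so multiplying the measured time by norm = f_cur/f_max estimates the time at full clock and reduces noise across DVFS states. A minimal self-contained sketch of this loop, with the Triton driver queries replaced by stand-ins (the stand-in names and values are hypothetical):

    #include <chrono>
    #include <functional>

    // stand-ins for cu_device->current_sm_clock() / cu_device->max_sm_clock()
    static float current_sm_clock() { return 1200.f; }  // hypothetical MHz
    static float max_sm_clock()     { return 1500.f; }

    double bench_ns(const std::function<void()>& op) {
      double total = 0;
      size_t n = 0;
      while (total * 1e-9 < 1e-3) {            // accumulate at least 1 ms of samples
        // credit the kernel with the time it would take at the maximum clock
        float norm = current_sm_clock() / max_sm_clock();
        auto t0 = std::chrono::high_resolution_clock::now();
        op();                                   // stream->synchronize() follows op() in the diff
        auto t1 = std::chrono::high_resolution_clock::now();
        total += std::chrono::duration<double, std::nano>(t1 - t0).count() * norm;
        n++;
      }
      return total / n;                         // mean normalized nanoseconds
    }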
@@ -30,8 +30,8 @@ inline bool is_hmma(ir::value *v){
   ir::type *b_ty = b->get_type();
   // inputs have to be FP16
   result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
-  // reduction has to be multiple of 4
-  result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
+  // reduction has to be multiple of 4
+  // result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
 }
 return result;
 }
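The predicate being relaxed here reads more clearly outside the IR plumbing. A hedged restatement with simplified stand-in types (the struct and field names are illustrative, not Triton's):

    // A dot is eligible for the HMMA (tensor-core) path when both inputs
    // are half precision; the "reduction multiple of 4" test is the
    // constraint this commit comments out.
    struct tile_desc { bool is_half; int reduction_size; };

    bool hmma_eligible(const tile_desc& a, const tile_desc& b) {
      return a.is_half && b.is_half;  // inputs have to be FP16
    }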
@@ -70,13 +70,13 @@ void optimize_dot::run(ir::module &mod) {
   BB = ((ir::trans_inst*)B)->get_operand(0);
 }
 else{
-  if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
-    std::vector<ir::constant_int*> perm(T->get_perm());
-    std::swap(perm[0], perm[1]);
-    AA = builder.create_trans(T->get_operand(0), perm);
-    T->replace_all_uses_with(AA);
-    trans_a = true;
-  }
+  // if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
+  //   std::vector<ir::constant_int*> perm(T->get_perm());
+  //   std::swap(perm[0], perm[1]);
+  //   AA = builder.create_trans(T->get_operand(0), perm);
+  //   T->replace_all_uses_with(AA);
+  //   trans_a = true;
+  // }
 }
 ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
 dot->replace_all_uses_with(dot_atbt);

@@ -516,6 +516,10 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
   }
 }
 else {
+  if(shapes.size() > 3)
+    throw std::runtime_error("unsupported");
+  bool is_batched = shapes.size() >= 3;
+
   Value *_1 = builder.getInt32(1);
   Value *_2 = builder.getInt32(2);
   Value *_3 = builder.getInt32(3);
@@ -526,23 +530,23 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
 // fragments per warp
 unsigned fpw_0 = params_->get_param(v, "fpw.d0")->get_value();
 unsigned fpw_1 = params_->get_param(v, "fpw.d1")->get_value();
-unsigned fpw_2 = params_->get_param(v, "fpw.d2")->get_value();
+unsigned fpw_2 = is_batched ? params_->get_param(v, "fpw.d2")->get_value() : 1;
 // warps per tile
 unsigned wpt_0 = params_->get_param(v, "wpt.d0")->get_value();
 unsigned wpt_1 = params_->get_param(v, "wpt.d1")->get_value();
-unsigned wpt_2 = params_->get_param(v, "wpt.d2")->get_value();
+unsigned wpt_2 = is_batched ? params_->get_param(v, "wpt.d2")->get_value() : 1;
 // hmma warp tile size
 unsigned hmma_wts_0 = fpw_0 * 8;
 unsigned hmma_wts_1 = fpw_1 * 8;
-unsigned hmma_wts_2 = 1;
+unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
 // hmma block tile size
 unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
 unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
-unsigned hmma_bts_2 = hmma_wts_2 * wpt_2;
+unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
 // number of repetition
 unsigned num_rep_0 = shapes[0]->get_value() / hmma_bts_0;
 unsigned num_rep_1 = shapes[1]->get_value() / hmma_bts_1;
-unsigned num_rep_2 = shapes[2]->get_value() / hmma_bts_2;
+unsigned num_rep_2 = is_batched ? shapes[2]->get_value() / hmma_bts_2 : 1;
 // size of each pack (interleaving)
 pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
 pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
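To make the tile arithmetic concrete, take illustrative values (not from this diff): fpw_0 = 2 fragments per warp and wpt_0 = 4 warps per tile give a warp tile of hmma_wts_0 = 2 * 8 = 16 and a block tile of hmma_bts_0 = 16 * 4 = 64, so a 128-wide tile dimension is covered in num_rep_0 = 128 / 64 = 2 repetitions. With the guards above, a non-batched shape simply pins fpw_2 = wpt_2 = hmma_bts_2 = num_rep_2 = 1 instead of reading the absent d2 parameters.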
@@ -579,19 +583,15 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
 // a offset
 offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
 offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
-offset_a_z_ = warp_id_2;
 // b offsets
 offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
 offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
-offset_b_z_ = warp_id_2;
-

 // c offsets
 Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
 Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
                                       builder.CreateAdd(warp_offset_j, pair_b_off));
-

 /* indices */
 // i indices
 std::vector<Value*> idx_i;
@@ -617,7 +617,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
 /* axes */
 axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
 axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
-axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
+if(is_batched)
+  axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
 }
 }

@@ -1062,10 +1063,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
 std::map<Value*, std::vector<Value*>> fcs;

 result->for_each([&](indices_t idx){
-  fcs[idx[2]].push_back(TC->get_value(idx));
-  // fc.push_back(UndefValue::get(TC->get_value(idx)->getType()));
+  fcs[{builder.getInt32(0)}].push_back(TC->get_value(idx));
 });

 Type *fp32_ty = builder.getFloatTy();
 Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
 Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
@@ -1121,8 +1122,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
   std::swap(idx_a[0], idx_a[1]);
 if(!dot->is_b_trans())
   std::swap(idx_b[0], idx_b[1]);
-idx_a.push_back(x.first);
-idx_b.push_back(x.first);
+// idx_a.push_back(builder.getInt32(0));
+// idx_b.push_back(builder.getInt32(0));
 Value *ha = TA->get_value(idx_a);
 Value *hb = TB->get_value(idx_b);
 for(unsigned ii = 0; ii < pack_size_0_; ii++)
@@ -1158,9 +1159,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
 // write back
 unsigned i = 0;
 result->for_each([&](indices_t idx){
-  if(i >= fcs.at(idx[2]).size())
+  if(i >= fcs.at({builder.getInt32(0)}).size())
     i = 0;
-  result->set_value(idx, fcs.at(idx[2])[i++]);
+  result->set_value(idx, fcs.at({builder.getInt32(0)})[i++]);
 });

 TA->set_return_mode(false);

@@ -257,56 +257,54 @@ void tune::run(ir::module &mod) {
       ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
       if(node.second == 2)
         fpw->set_value(1);
-      ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
+      ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
       connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
     }
   }
 }

 // Simplify metaparameters
 for(ir::function *fn: mod.get_function_list())
 for(ir::basic_block *block: fn->blocks())
 for(ir::instruction *i : block->get_inst_list()){

-  if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
+  if(!i->get_type()->is_tile_ty())
     continue;
+  auto shapes = i->get_type()->get_tile_shapes();

-  if(auto *x = dynamic_cast<ir::load_inst*>(i))
-  if(i->get_type()->is_tile_ty()){
+  if(auto *x = dynamic_cast<ir::load_inst*>(i)){
     ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
     size_t addr_space = ptr_ty->get_pointer_address_space();
     if(addr_space < 4){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
+      std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
       *params_.at(i).at("nts.d0") = *tmp;
     }
   }
   if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
     ir::type *ty = mod.get_builder().get_int32_ty();
-    std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
-    std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
-    *params_.at(i).at("nts.d0") = *tmp1;
-    *params_.at(i).at("nts.d1") = *tmp2;
-    // std::unique_ptr<ir::metaparameter> mts_2(ir::metaparameter::create(ctx, ty, 1, 4));
-    // *params_.at(i->get_operand(0)).at("mts.d2") = *mts_2;
-    // *params_.at(i->get_operand(1)).at("mts.d2") = *mts_2;
+    if(fragments_.at({i, 0}) == STRIDED_SCAN){
+      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
+      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
+      *params_.at(i).at("nts.d0") = *tmp1;
+      *params_.at(i).at("nts.d1") = *tmp2;
+      // for(size_t k = 2; k < shapes.size(); k++)
+      //   if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
+      //     *params_.at(i).at("mts.d" + std::to_string(k)) = *x;
+      //   else
+      //     params_.at(i).at("mts.d" + std::to_string(k))->set_value(shapes[k]->get_value());
+    }
+    else{
+      // for(size_t k = 2; k < shapes.size(); k++)
+      //   if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
+      //     *params_.at(i).at("wpt.d" + std::to_string(k)) = *x;
+      //   else
+      //     params_.at(i).at("wpt.d" + std::to_string(k))->set_value(shapes[k]->get_value());
+    }
   }
 }

 // initialize grids
 // for(ir::instruction *i: grids_){
 //   auto shapes = i->get_type()->get_tile_shapes();
 //   for(size_t k = 0; k < shapes.size(); k++)
 //     if(shapes[k]->get_value() == 1) {
 //       if(fragments_.at({i, k}) == STRIDED_SCAN){
 //         params_.at(i).at("nts.d" + std::to_string(k))->set_value(1);
 //         params_.at(i).at("mts.d" + std::to_string(k))->set_value(1);
 //       }
 //       if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){
 //         params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1);
 //         params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1);
 //       }
 //     }
 //   }
 // }
 }

 void tune::init(ir::module &mod) {

@@ -62,11 +62,7 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
   jit->add_module(name_.c_str(), src.c_str(), best.params);
 }
 else{
-  // params_t params = heuristics();
-  // params_t params = jit->get_valid(name_.c_str(), src.c_str());
-  // params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
-  // params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
-  params_t params = {4, 2, 16, 4, 2, 16, 2, 2, 1, 1, 2, 16, 32, 16, 4, 4, 4, 4, 1}; // TT
+  params_t params = heuristics();
   jit->add_module(name_.c_str(), src.c_str(), params);
 }
 triton::driver::kernel* kernel = jit->get_function(name_.c_str());

@@ -72,11 +72,11 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
 }

 void dot::triton_c_src(std::ostream &os) const {
-  std::string ZS = "4";
+  std::string ZS = "1";
   std::string AS0 = "TM", AS1 = "TK";
   std::string BS0 = "TK", BS1 = "TN";
-  std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
-  std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
+  std::string XAS0 = "TM", XAS1 = "TK", XAS2 = ZS;
+  std::string XBS0 = "TK", XBS1 = ZS, XBS2 = "TN";
   std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
   std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
   std::string lda0 = "*lda", lda1 = "";
@@ -131,9 +131,9 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
   bool checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
   )" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
   )" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
-  )" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
-  )" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
   for(int k = K; k > 0; k = k - TK){
+    )" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
+    )" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
     xc = dot()" + usea + ", " + useb + R"(, xc);
     pa = pa + TK)" + lda0 + R"(;
     pb = pb + TK)" + ldb0 + R"(;
@@ -141,12 +141,17 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
     bool checkb[)" + BS + R"(] = k > TK;
     a = checka ? *pa : 0;
     b = checkb ? *pb : 0;
+    xa = __reshape(a, )" + XAS + R"();
+    xb = __reshape(b, )" + XBS + R"();
   }
   int rxc[TM] = ridx * TM + (0 ... TM);
   int ryc[TN] = ridy * TN + (0 ... TN);
   )" + c_ty_ + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
   )" + c_ty_ + R"( c[TM, TN] = __sum(xc, 2);
-  *pc = c;
+  bool checkc0[TM] = rxc < M;
+  bool checkc1[TN] = ryc < N;
+  bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
+  @checkc *pc = c;
 }
 )";

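Reading the shape strings above: ZS is the batch ("z") tile size of the generated Triton-C. The old code set ZS = "4" and split the reduction axis as XAS1 = TK/ZS with XAS2 = ZS; with ZS = "1" and XAS1 = TK, xa has shape [TM, TK, 1] and xb has shape [TK, 1, TN], so __sum(xc, 2) reduces over a singleton axis. In other words, the batched code path is kept in place but currently degenerates to a plain GEMM.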
@@ -255,6 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
 cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }

 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
+  // std::cout << source << std::endl;
   cu_context::context_switcher ctx_switch(*context);
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
   size_t D = ranges.size();
   std::vector<size_t> values(D, 0);
   // thread pools
-  // ThreadPool pool(nthreads);
+  ThreadPool pool(nthreads);
   // Start with innermost loop
   size_t i = D - 1;
   while(true){
     // Execute function
-    // pool.enqueue(f,values);
-    f(values);
+    pool.enqueue(f,values);
+    // f(values);
     while(values[i]++ == ranges[i] - 1){
       if(i == 0)
         return;
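This odometer-style loop nest enumerates every point of a D-dimensional index space without D nested loops, which is how the autotuner can walk an arbitrary metaparameter grid. A serial, self-contained sketch (the carry logic past `return;` is reconstructed from the visible pattern, and the ThreadPool swap from the diff is noted in comments; assumes ranges is non-empty):

    #include <cstddef>
    #include <functional>
    #include <vector>

    void loop_nest(const std::vector<size_t>& ranges,
                   const std::function<void(const std::vector<size_t>&)>& f) {
      size_t D = ranges.size();
      std::vector<size_t> values(D, 0);
      size_t i = D - 1;                  // start with the innermost loop
      while (true) {
        f(values);                       // the commit replaces this with pool.enqueue(f, values)
        // advance the multi-index like an odometer, carrying into outer loops
        while (values[i]++ == ranges[i] - 1) {
          if (i == 0)
            return;                      // outermost digit overflowed: done
          values[i--] = 0;               // reset this digit, move one loop outward
        }
        i = D - 1;                       // resume at the innermost loop
      }
    }

For example, loop_nest({2, 3}, f) calls f on {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2}.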
@@ -174,9 +174,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
   std::lock_guard<std::mutex> lock(mutex);
   for(ir::metaparameter *mp: mps)
     mp->set_value(params[i++]);
-  for(size_t i = 0; i < params.size(); i++)
-    std::cout << ((i==0)?"":", ") << params[i] << std::flush;
-  std::cout << std::endl;
+  // for(size_t i = 0; i < params.size(); i++)
+  //   std::cout << ((i==0)?"":", ") << params[i] << std::flush;
+  // std::cout << std::endl;
   passes_0.tune.init(tt_module_0);
   passes_0.tune.check_constraints(errors);
   // for(auto x: errors)