dirty but working warp-splitting

This commit is contained in:
Philippe Tillet
2019-08-06 21:07:13 -07:00
parent 494bfa7671
commit 7b75b68edc
13 changed files with 132 additions and 69 deletions

View File

@@ -113,9 +113,9 @@ int main() {
// {false, false, 8192, 512, 512},
// {false, true, 8192, 8192, 8192}
{false, true, 128, 128, 128},
{false, false, 128, 128, 128},
{true, false, 128, 128, 128},
{true, true, 128, 128, 128}
// {false, false, 128, 128, 128},
// {true, false, 128, 128, 128},
// {true, true, 128, 128, 128}
// {false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512},

View File

@@ -153,8 +153,8 @@ private:
alignment_info *axis_info_;
std::map<unsigned, distributed_axis> axes_;
llvm::Value *sh_mem_ptr_;
llvm::Value *offset_a_i_, *offset_a_k_;
llvm::Value *offset_b_j_, *offset_b_k_;
llvm::Value *offset_a_i_, *offset_a_k_, *offset_a_z_;
llvm::Value *offset_b_j_, *offset_b_k_, *offset_b_z_;
unsigned num_packs_0_, num_packs_1_;
unsigned pack_size_0_, pack_size_1_;
};

View File

@@ -77,7 +77,7 @@ public:
void target_dependent(ir::module &module) {
alignment_info.run(module);
// reassociate.run(module);
reassociate.run(module);
if(target_->is_gpu()){
shmem_info.run(module);
shmem_liveness.run(module);
@@ -86,7 +86,7 @@ public:
}
vectorize.run(module);
optimize_dce.run(module);
// ir::print(module, std::cout);
ir::print(module, std::cout);
}
codegen::tune tune;

View File

@@ -18,7 +18,7 @@ void optimize_dce::run(ir::module &mod) {
// iterate through blocks
for(ir::basic_block *block: rpo)
for(ir::instruction *i: block->get_inst_list()){
if(dynamic_cast<ir::io_inst*>(i) || dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::return_inst*>(i)
if(dynamic_cast<ir::io_inst*>(i) || dynamic_cast<ir::return_inst*>(i)
|| dynamic_cast<ir::branch_inst*>(i) || dynamic_cast<ir::cond_branch_inst*>(i)
|| dynamic_cast<ir::atomic_cas_inst*>(i) || dynamic_cast<ir::atomic_exch_inst*>(i) || dynamic_cast<ir::atomic_add_inst*>(i)
|| dynamic_cast<ir::barrier_inst*>(i)){

View File

@@ -57,9 +57,27 @@ void optimize_dot::run(ir::module &mod) {
if(trans_a){
AA = ((ir::trans_inst*)A)->get_operand(0);
}
else{
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
std::vector<ir::constant_int*> perm(T->get_perm());
std::swap(perm[0], perm[1]);
AA = builder.create_trans(T->get_operand(0), perm);
T->replace_all_uses_with(AA);
trans_a = true;
}
}
if(trans_b){
BB = ((ir::trans_inst*)B)->get_operand(0);
}
else{
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
std::vector<ir::constant_int*> perm(T->get_perm());
std::swap(perm[0], perm[1]);
AA = builder.create_trans(T->get_operand(0), perm);
T->replace_all_uses_with(AA);
trans_a = true;
}
}
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
dot->replace_all_uses_with(dot_atbt);
}

View File

@@ -138,8 +138,9 @@ Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& sh
Value *ld = builder.getInt32(shapes[0]);
for(size_t i = 1; i < idx.size(); i++) {
result = builder.CreateAdd(result, builder.CreateMul(idx[i], ld));
if(i < idx.size() - 1)
if(i < idx.size() - 1){
ld = builder.CreateMul(ld, builder.getInt32(shapes[i]));
}
}
return result;
}
@@ -525,18 +526,23 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
// fragments per warp
unsigned fpw_0 = params_->get_param(v, "fpw.d0")->get_value();
unsigned fpw_1 = params_->get_param(v, "fpw.d1")->get_value();
unsigned fpw_2 = params_->get_param(v, "fpw.d2")->get_value();
// warps per tile
unsigned wpt_0 = params_->get_param(v, "wpt.d0")->get_value();
unsigned wpt_1 = params_->get_param(v, "wpt.d1")->get_value();
unsigned wpt_2 = params_->get_param(v, "wpt.d2")->get_value();
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = 1;
// hmma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = hmma_wts_2 * wpt_2;
// number of repetition
unsigned num_rep_0 = shapes[0]->get_value() / hmma_bts_0;
unsigned num_rep_1 = shapes[1]->get_value() / hmma_bts_1;
unsigned num_rep_2 = shapes[2]->get_value() / hmma_bts_2;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
@@ -563,7 +569,9 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
/* inter warp offset */
Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
Value *warp_id_1 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1));
Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1));
Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_));
Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_));
@@ -571,9 +579,11 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
// a offset
offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
offset_a_z_ = warp_id_2;
// b offsets
offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
offset_b_z_ = warp_id_2;
// c offsets
@@ -598,10 +608,16 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
// z indices
std::vector<Value*> idx_z;
for(unsigned pack = 0; pack < num_rep_2; pack++)
idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2)));
/* axes */
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
}
}
@@ -851,7 +867,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
Type *res_ty = builder.getFloatTy();
Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
unsigned depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
for(auto& x: partial) {
// current element being computed
Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
@@ -867,6 +882,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
tgt_->add_barrier(module, builder);
builder.CreateStore(result, write_ptr);
// build result
unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
for(unsigned i = depth/2; i > 0; i >>= 1){
// current indices
indices_t current(write_idx.size(), builder.getInt32(0));
@@ -999,6 +1015,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
auto A_shapes = A->get_type()->get_tile_shapes();
size_t red_axis = dot->is_a_trans() ? 0 : 1;
unsigned NK = A_shapes[red_axis]->get_value();
// std::cout << red_axis << " " << NK << std::endl;
if(NK != 1)
{
@@ -1042,10 +1059,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
TA->set_return_mode(true);
TB->set_return_mode(true);
std::vector<Value *> fc;
std::map<Value*, std::vector<Value*>> fcs;
result->for_each([&](indices_t idx){
fc.push_back(TC->get_value(idx));
fcs[idx[2]].push_back(TC->get_value(idx));
// fc.push_back(UndefValue::get(TC->get_value(idx)->getType()));
});
@@ -1088,53 +1105,62 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
unsigned stride_rep_j = wpt_1 * wts_1;
unsigned num_rep_i = shapes[0]->get_value() / stride_rep_i;
unsigned ld_fc = num_rep_i * 2;
for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
for(unsigned K = 0; K < NK; K += 4){
Value *_K = builder.getInt32(K);
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
if(dot->is_a_trans())
std::swap(idx_a[0], idx_a[1]);
if(!dot->is_b_trans())
std::swap(idx_b[0], idx_b[1]);
Value *ha = TA->get_value(idx_a);
Value *hb = TB->get_value(idx_b);
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned jj = 0; jj < pack_size_1_; jj++){
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
std::vector<size_t> idx = {
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc
};
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
fc[idx[0]] = builder.CreateExtractValue(nc, {0});
fc[idx[1]] = builder.CreateExtractValue(nc, {1});
fc[idx[2]] = builder.CreateExtractValue(nc, {2});
fc[idx[3]] = builder.CreateExtractValue(nc, {3});
fc[idx[4]] = builder.CreateExtractValue(nc, {4});
fc[idx[5]] = builder.CreateExtractValue(nc, {5});
fc[idx[6]] = builder.CreateExtractValue(nc, {6});
fc[idx[7]] = builder.CreateExtractValue(nc, {7});
for(auto& x: fcs){
std::vector<Value *>& fc = x.second;
for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
for(unsigned K = 0; K < NK; K += 4){
Value *_K = builder.getInt32(K);
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
if(dot->is_a_trans())
std::swap(idx_a[0], idx_a[1]);
if(!dot->is_b_trans())
std::swap(idx_b[0], idx_b[1]);
idx_a.push_back(x.first);
idx_b.push_back(x.first);
Value *ha = TA->get_value(idx_a);
Value *hb = TB->get_value(idx_b);
for(unsigned ii = 0; ii < pack_size_0_; ii++)
for(unsigned jj = 0; jj < pack_size_1_; jj++){
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
std::vector<size_t> idx = {
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc
};
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
fc[idx[0]] = builder.CreateExtractValue(nc, {0});
fc[idx[1]] = builder.CreateExtractValue(nc, {1});
fc[idx[2]] = builder.CreateExtractValue(nc, {2});
fc[idx[3]] = builder.CreateExtractValue(nc, {3});
fc[idx[4]] = builder.CreateExtractValue(nc, {4});
fc[idx[5]] = builder.CreateExtractValue(nc, {5});
fc[idx[6]] = builder.CreateExtractValue(nc, {6});
fc[idx[7]] = builder.CreateExtractValue(nc, {7});
}
}
}
}
}
// write back
unsigned i = 0;
result->for_each([&](indices_t idx){
result->set_value(idx, fc[i++]);
if(i >= fcs.at(idx[2]).size())
i = 0;
result->set_value(idx, fcs.at(idx[2])[i++]);
});
TA->set_return_mode(false);

View File

@@ -12,8 +12,10 @@ namespace triton{
namespace codegen{
unsigned shmem_allocation::is_ld_padded(ir::value *x) {
if(dynamic_cast<ir::trans_inst*>(x))
return 4;
if(auto *trans = dynamic_cast<ir::trans_inst*>(x)){
if(trans->get_perm()[0]->get_value() != 0)
return 4;
}
for(ir::user* user: x->get_users())
if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
bool is_hmma = params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C;
@@ -51,7 +53,11 @@ unsigned shmem_allocation::get_num_bytes(ir::value *x) {
size_t num_elements = 1;
for(auto x: shapes)
num_elements *= x->get_value();
size_t depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
size_t depth;
if(params_->get_fragment(x, 0) == tune::HMMA_FRAGMENT_C)
depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
else
depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
return num_elements * num_bytes * depth;
}
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;

View File

@@ -255,6 +255,8 @@ void tune::run(ir::module &mod) {
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
if(node.second == 2)
fpw->set_value(1);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}

View File

@@ -13,7 +13,17 @@ void vectorize::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
for(ir::instruction *i: block->get_inst_list()){
if(auto *trans = dynamic_cast<ir::trans_inst*>(i)){
ir::value *x = i->get_operand(0);
if(trans->get_perm()[0]->get_value() != 0)
continue;
builder.set_insert_point(i);
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
x->replace_all_uses_with(rx);
rx->set_operand(0, x);
params_->copy(rx, x);
}
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
ir::value *x = i->get_operand(0);
if(params_->get_param(x, "nts.d0")->get_value() == 1)
@@ -24,6 +34,7 @@ void vectorize::run(ir::module &mod) {
rx->set_operand(0, x);
params_->copy(rx, x);
}
}
}
}

View File

@@ -62,11 +62,11 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else{
params_t params = heuristics();
// params_t params = heuristics();
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
// params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
// params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
// params_t params = {4, 16, 4, 2, 16, 4, 8, 2, 2, 8, 2, 32, 8, 1}; // TT
params_t params = {4, 2, 16, 4, 2, 16, 2, 2, 1, 1, 2, 16, 32, 16, 4, 4, 4, 4, 1}; // TT
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());

View File

@@ -72,10 +72,11 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
}
void dot::triton_c_src(std::ostream &os) const {
std::string ZS = "4";
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK", XAS2 = "1";
std::string XBS0 = "TK", XBS1 = "1", XBS2 = "TN";
std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
@@ -100,7 +101,7 @@ void dot::triton_c_src(std::ostream &os) const {
std::string BS = BS0 + ", " + BS1;
std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
std::string XCS = "TM, TN, 1";
std::string XCS = "TM, TN, " + ZS;
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =

View File

@@ -255,7 +255,6 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -174,9 +174,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
std::lock_guard<std::mutex> lock(mutex);
for(ir::metaparameter *mp: mps)
mp->set_value(params[i++]);
// for(size_t i = 0; i < params.size(); i++)
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
// std::cout << std::endl;
for(size_t i = 0; i < params.size(); i++)
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << std::endl;
passes_0.tune.init(tt_module_0);
passes_0.tune.check_constraints(errors);
// for(auto x: errors)