dirty but working warp-splitting
This commit is contained in:
@@ -113,9 +113,9 @@ int main() {
|
||||
// {false, false, 8192, 512, 512},
|
||||
// {false, true, 8192, 8192, 8192}
|
||||
{false, true, 128, 128, 128},
|
||||
{false, false, 128, 128, 128},
|
||||
{true, false, 128, 128, 128},
|
||||
{true, true, 128, 128, 128}
|
||||
// {false, false, 128, 128, 128},
|
||||
// {true, false, 128, 128, 128},
|
||||
// {true, true, 128, 128, 128}
|
||||
|
||||
// {false, true, 32768, 256, 512}
|
||||
// {true, false, 8192, 512, 512},
|
||||
|
@@ -153,8 +153,8 @@ private:
|
||||
alignment_info *axis_info_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
llvm::Value *sh_mem_ptr_;
|
||||
llvm::Value *offset_a_i_, *offset_a_k_;
|
||||
llvm::Value *offset_b_j_, *offset_b_k_;
|
||||
llvm::Value *offset_a_i_, *offset_a_k_, *offset_a_z_;
|
||||
llvm::Value *offset_b_j_, *offset_b_k_, *offset_b_z_;
|
||||
unsigned num_packs_0_, num_packs_1_;
|
||||
unsigned pack_size_0_, pack_size_1_;
|
||||
};
|
||||
|
@@ -77,7 +77,7 @@ public:
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
alignment_info.run(module);
|
||||
// reassociate.run(module);
|
||||
reassociate.run(module);
|
||||
if(target_->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
@@ -86,7 +86,7 @@ public:
|
||||
}
|
||||
vectorize.run(module);
|
||||
optimize_dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
codegen::tune tune;
|
||||
|
@@ -18,7 +18,7 @@ void optimize_dce::run(ir::module &mod) {
|
||||
// iterate through blocks
|
||||
for(ir::basic_block *block: rpo)
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(dynamic_cast<ir::io_inst*>(i) || dynamic_cast<ir::copy_to_shared_inst*>(i) || dynamic_cast<ir::return_inst*>(i)
|
||||
if(dynamic_cast<ir::io_inst*>(i) || dynamic_cast<ir::return_inst*>(i)
|
||||
|| dynamic_cast<ir::branch_inst*>(i) || dynamic_cast<ir::cond_branch_inst*>(i)
|
||||
|| dynamic_cast<ir::atomic_cas_inst*>(i) || dynamic_cast<ir::atomic_exch_inst*>(i) || dynamic_cast<ir::atomic_add_inst*>(i)
|
||||
|| dynamic_cast<ir::barrier_inst*>(i)){
|
||||
|
@@ -57,9 +57,27 @@ void optimize_dot::run(ir::module &mod) {
|
||||
if(trans_a){
|
||||
AA = ((ir::trans_inst*)A)->get_operand(0);
|
||||
}
|
||||
else{
|
||||
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
|
||||
std::vector<ir::constant_int*> perm(T->get_perm());
|
||||
std::swap(perm[0], perm[1]);
|
||||
AA = builder.create_trans(T->get_operand(0), perm);
|
||||
T->replace_all_uses_with(AA);
|
||||
trans_a = true;
|
||||
}
|
||||
}
|
||||
if(trans_b){
|
||||
BB = ((ir::trans_inst*)B)->get_operand(0);
|
||||
}
|
||||
else{
|
||||
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
|
||||
std::vector<ir::constant_int*> perm(T->get_perm());
|
||||
std::swap(perm[0], perm[1]);
|
||||
AA = builder.create_trans(T->get_operand(0), perm);
|
||||
T->replace_all_uses_with(AA);
|
||||
trans_a = true;
|
||||
}
|
||||
}
|
||||
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
|
||||
dot->replace_all_uses_with(dot_atbt);
|
||||
}
|
||||
|
@@ -138,8 +138,9 @@ Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& sh
|
||||
Value *ld = builder.getInt32(shapes[0]);
|
||||
for(size_t i = 1; i < idx.size(); i++) {
|
||||
result = builder.CreateAdd(result, builder.CreateMul(idx[i], ld));
|
||||
if(i < idx.size() - 1)
|
||||
if(i < idx.size() - 1){
|
||||
ld = builder.CreateMul(ld, builder.getInt32(shapes[i]));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -525,18 +526,23 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
// fragments per warp
|
||||
unsigned fpw_0 = params_->get_param(v, "fpw.d0")->get_value();
|
||||
unsigned fpw_1 = params_->get_param(v, "fpw.d1")->get_value();
|
||||
unsigned fpw_2 = params_->get_param(v, "fpw.d2")->get_value();
|
||||
// warps per tile
|
||||
unsigned wpt_0 = params_->get_param(v, "wpt.d0")->get_value();
|
||||
unsigned wpt_1 = params_->get_param(v, "wpt.d1")->get_value();
|
||||
unsigned wpt_2 = params_->get_param(v, "wpt.d2")->get_value();
|
||||
// hmma warp tile size
|
||||
unsigned hmma_wts_0 = fpw_0 * 8;
|
||||
unsigned hmma_wts_1 = fpw_1 * 8;
|
||||
unsigned hmma_wts_2 = 1;
|
||||
// hmma block tile size
|
||||
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
|
||||
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
|
||||
unsigned hmma_bts_2 = hmma_wts_2 * wpt_2;
|
||||
// number of repetition
|
||||
unsigned num_rep_0 = shapes[0]->get_value() / hmma_bts_0;
|
||||
unsigned num_rep_1 = shapes[1]->get_value() / hmma_bts_1;
|
||||
unsigned num_rep_2 = shapes[2]->get_value() / hmma_bts_2;
|
||||
// size of each pack (interleaving)
|
||||
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
|
||||
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
|
||||
@@ -563,7 +569,9 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
|
||||
/* inter warp offset */
|
||||
Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
|
||||
Value *warp_id_1 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
|
||||
Value *warp_id_12 = builder.CreateUDiv(u_warp_id, builder.getInt32(wpt_0));
|
||||
Value *warp_id_1 = builder.CreateURem(warp_id_12, builder.getInt32(wpt_1));
|
||||
Value *warp_id_2 = builder.CreateUDiv(warp_id_12, builder.getInt32(wpt_1));
|
||||
Value *warp_offset_i = builder.CreateMul(warp_id_0, builder.getInt32(hmma_wts_0 * pack_size_0_));
|
||||
Value *warp_offset_j = builder.CreateMul(warp_id_1, builder.getInt32(hmma_wts_1 * pack_size_1_));
|
||||
|
||||
@@ -571,9 +579,11 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
// a offset
|
||||
offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
|
||||
offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
|
||||
offset_a_z_ = warp_id_2;
|
||||
// b offsets
|
||||
offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
|
||||
offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
|
||||
offset_b_z_ = warp_id_2;
|
||||
|
||||
|
||||
// c offsets
|
||||
@@ -598,10 +608,16 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
|
||||
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
|
||||
}
|
||||
// z indices
|
||||
std::vector<Value*> idx_z;
|
||||
for(unsigned pack = 0; pack < num_rep_2; pack++)
|
||||
idx_z.push_back(builder.CreateAdd(warp_id_2, builder.getInt32(pack*hmma_bts_2)));
|
||||
|
||||
|
||||
/* axes */
|
||||
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
|
||||
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
|
||||
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
|
||||
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
|
||||
axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -851,7 +867,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
||||
Type *res_ty = builder.getFloatTy();
|
||||
Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
||||
unsigned depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
|
||||
for(auto& x: partial) {
|
||||
// current element being computed
|
||||
Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
|
||||
@@ -867,6 +882,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
tgt_->add_barrier(module, builder);
|
||||
builder.CreateStore(result, write_ptr);
|
||||
// build result
|
||||
unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
|
||||
for(unsigned i = depth/2; i > 0; i >>= 1){
|
||||
// current indices
|
||||
indices_t current(write_idx.size(), builder.getInt32(0));
|
||||
@@ -999,6 +1015,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
auto A_shapes = A->get_type()->get_tile_shapes();
|
||||
size_t red_axis = dot->is_a_trans() ? 0 : 1;
|
||||
unsigned NK = A_shapes[red_axis]->get_value();
|
||||
|
||||
// std::cout << red_axis << " " << NK << std::endl;
|
||||
if(NK != 1)
|
||||
{
|
||||
@@ -1042,10 +1059,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
TA->set_return_mode(true);
|
||||
TB->set_return_mode(true);
|
||||
|
||||
std::vector<Value *> fc;
|
||||
std::map<Value*, std::vector<Value*>> fcs;
|
||||
|
||||
result->for_each([&](indices_t idx){
|
||||
fc.push_back(TC->get_value(idx));
|
||||
fcs[idx[2]].push_back(TC->get_value(idx));
|
||||
// fc.push_back(UndefValue::get(TC->get_value(idx)->getType()));
|
||||
});
|
||||
|
||||
@@ -1088,53 +1105,62 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
unsigned stride_rep_j = wpt_1 * wts_1;
|
||||
unsigned num_rep_i = shapes[0]->get_value() / stride_rep_i;
|
||||
unsigned ld_fc = num_rep_i * 2;
|
||||
for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
|
||||
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
|
||||
for(unsigned K = 0; K < NK; K += 4){
|
||||
Value *_K = builder.getInt32(K);
|
||||
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
|
||||
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
|
||||
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
|
||||
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
|
||||
if(dot->is_a_trans())
|
||||
std::swap(idx_a[0], idx_a[1]);
|
||||
if(!dot->is_b_trans())
|
||||
std::swap(idx_b[0], idx_b[1]);
|
||||
Value *ha = TA->get_value(idx_a);
|
||||
Value *hb = TB->get_value(idx_b);
|
||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||
for(unsigned jj = 0; jj < pack_size_1_; jj++){
|
||||
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
|
||||
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
|
||||
std::vector<size_t> idx = {
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc
|
||||
};
|
||||
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
|
||||
fc[idx[0]] = builder.CreateExtractValue(nc, {0});
|
||||
fc[idx[1]] = builder.CreateExtractValue(nc, {1});
|
||||
fc[idx[2]] = builder.CreateExtractValue(nc, {2});
|
||||
fc[idx[3]] = builder.CreateExtractValue(nc, {3});
|
||||
fc[idx[4]] = builder.CreateExtractValue(nc, {4});
|
||||
fc[idx[5]] = builder.CreateExtractValue(nc, {5});
|
||||
fc[idx[6]] = builder.CreateExtractValue(nc, {6});
|
||||
fc[idx[7]] = builder.CreateExtractValue(nc, {7});
|
||||
|
||||
|
||||
for(auto& x: fcs){
|
||||
std::vector<Value *>& fc = x.second;
|
||||
for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
|
||||
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
|
||||
for(unsigned K = 0; K < NK; K += 4){
|
||||
Value *_K = builder.getInt32(K);
|
||||
Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
|
||||
Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
|
||||
indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)};
|
||||
indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)};
|
||||
if(dot->is_a_trans())
|
||||
std::swap(idx_a[0], idx_a[1]);
|
||||
if(!dot->is_b_trans())
|
||||
std::swap(idx_b[0], idx_b[1]);
|
||||
idx_a.push_back(x.first);
|
||||
idx_b.push_back(x.first);
|
||||
Value *ha = TA->get_value(idx_a);
|
||||
Value *hb = TB->get_value(idx_b);
|
||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||
for(unsigned jj = 0; jj < pack_size_1_; jj++){
|
||||
Value *ha0 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *ha1 = builder.CreateBitCast(builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1)), fp16x2_ty);
|
||||
Value *hb0 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0)), fp16x2_ty);
|
||||
Value *hb1 = builder.CreateBitCast(builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1)), fp16x2_ty);
|
||||
std::vector<size_t> idx = {
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
|
||||
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc
|
||||
};
|
||||
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
|
||||
fc[idx[0]] = builder.CreateExtractValue(nc, {0});
|
||||
fc[idx[1]] = builder.CreateExtractValue(nc, {1});
|
||||
fc[idx[2]] = builder.CreateExtractValue(nc, {2});
|
||||
fc[idx[3]] = builder.CreateExtractValue(nc, {3});
|
||||
fc[idx[4]] = builder.CreateExtractValue(nc, {4});
|
||||
fc[idx[5]] = builder.CreateExtractValue(nc, {5});
|
||||
fc[idx[6]] = builder.CreateExtractValue(nc, {6});
|
||||
fc[idx[7]] = builder.CreateExtractValue(nc, {7});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// write back
|
||||
unsigned i = 0;
|
||||
result->for_each([&](indices_t idx){
|
||||
result->set_value(idx, fc[i++]);
|
||||
if(i >= fcs.at(idx[2]).size())
|
||||
i = 0;
|
||||
result->set_value(idx, fcs.at(idx[2])[i++]);
|
||||
});
|
||||
|
||||
TA->set_return_mode(false);
|
||||
|
@@ -12,8 +12,10 @@ namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
unsigned shmem_allocation::is_ld_padded(ir::value *x) {
|
||||
if(dynamic_cast<ir::trans_inst*>(x))
|
||||
return 4;
|
||||
if(auto *trans = dynamic_cast<ir::trans_inst*>(x)){
|
||||
if(trans->get_perm()[0]->get_value() != 0)
|
||||
return 4;
|
||||
}
|
||||
for(ir::user* user: x->get_users())
|
||||
if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
|
||||
bool is_hmma = params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C;
|
||||
@@ -51,7 +53,11 @@ unsigned shmem_allocation::get_num_bytes(ir::value *x) {
|
||||
size_t num_elements = 1;
|
||||
for(auto x: shapes)
|
||||
num_elements *= x->get_value();
|
||||
size_t depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
|
||||
size_t depth;
|
||||
if(params_->get_fragment(x, 0) == tune::HMMA_FRAGMENT_C)
|
||||
depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
|
||||
else
|
||||
depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
|
||||
return num_elements * num_bytes * depth;
|
||||
}
|
||||
unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
|
||||
|
@@ -255,6 +255,8 @@ void tune::run(ir::module &mod) {
|
||||
}
|
||||
else {
|
||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||
if(node.second == 2)
|
||||
fpw->set_value(1);
|
||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
|
@@ -13,7 +13,17 @@ void vectorize::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(auto *trans = dynamic_cast<ir::trans_inst*>(i)){
|
||||
ir::value *x = i->get_operand(0);
|
||||
if(trans->get_perm()[0]->get_value() != 0)
|
||||
continue;
|
||||
builder.set_insert_point(i);
|
||||
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
|
||||
x->replace_all_uses_with(rx);
|
||||
rx->set_operand(0, x);
|
||||
params_->copy(rx, x);
|
||||
}
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
|
||||
ir::value *x = i->get_operand(0);
|
||||
if(params_->get_param(x, "nts.d0")->get_value() == 1)
|
||||
@@ -24,6 +34,7 @@ void vectorize::run(ir::module &mod) {
|
||||
rx->set_operand(0, x);
|
||||
params_->copy(rx, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -62,11 +62,11 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
|
||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||
}
|
||||
else{
|
||||
params_t params = heuristics();
|
||||
// params_t params = heuristics();
|
||||
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
|
||||
// params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
|
||||
// params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
|
||||
// params_t params = {4, 16, 4, 2, 16, 4, 8, 2, 2, 8, 2, 32, 8, 1}; // TT
|
||||
params_t params = {4, 2, 16, 4, 2, 16, 2, 2, 1, 1, 2, 16, 32, 16, 4, 4, 4, 4, 1}; // TT
|
||||
jit->add_module(name_.c_str(), src.c_str(), params);
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
|
@@ -72,10 +72,11 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
}
|
||||
|
||||
void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string ZS = "4";
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string XAS0 = "TM", XAS1 = "TK", XAS2 = "1";
|
||||
std::string XBS0 = "TK", XBS1 = "1", XBS2 = "TN";
|
||||
std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
|
||||
std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
@@ -100,7 +101,7 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
|
||||
std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
|
||||
std::string XCS = "TM, TN, 1";
|
||||
std::string XCS = "TM, TN, " + ZS;
|
||||
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
|
||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
|
||||
std::string res =
|
||||
|
@@ -255,7 +255,6 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -174,9 +174,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(params[i++]);
|
||||
// for(size_t i = 0; i < params.size(); i++)
|
||||
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
|
||||
// std::cout << std::endl;
|
||||
for(size_t i = 0; i < params.size(); i++)
|
||||
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
|
||||
std::cout << std::endl;
|
||||
passes_0.tune.init(tt_module_0);
|
||||
passes_0.tune.check_constraints(errors);
|
||||
// for(auto x: errors)
|
||||
|
Reference in New Issue
Block a user