From a1e0512703508eb3c11762e2eca7abca8896fdf7 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 3 Oct 2019 14:11:50 -0400 Subject: [PATCH] [codegen] more progress --- lib/codegen/analysis/liveness.cc | 58 +++++++++++++------------------- lib/codegen/analysis/tiles.cc | 28 +++++++-------- lib/codegen/selection.cc | 22 ++++++++++-- lib/driver/module.cc | 1 + tests/bench/dot.cc | 2 +- tests/common/dot.h | 8 ++--- 6 files changed, 63 insertions(+), 56 deletions(-) diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc index b1f75f03c..e23bb96da 100644 --- a/lib/codegen/analysis/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -91,46 +91,36 @@ void liveness::connected_components(node_t x, std::set &nodes, graph_t & bool liveness::do_pad(ir::value *x) { // alignment for matrix product -// if(auto* dot = dynamic_cast(x)) { -// auto order = tiles_->order(x); -// // a -// ir::value *a = dot->get_operand(0);\ -// size_t previous_a = pad_[a]; -// bool a_trans = dynamic_cast(a); -// bool a_row = order[0] == 0; -// if(tiles_->hmma(x) == HMMA_A_ROW) -// pad_[a] = 16; -// else if(tiles_->hmma(x) == HMMA_A_COL) -// pad_[a] = 8; -// else if(a_trans ^ a_row) -// pad_[a] = 4; -// else -// pad_[a] = 0; -// // b -// ir::value *b = dot->get_operand(1); -// size_t previous_b = pad_[b]; -// bool b_trans = dynamic_cast(b); -// bool b_col = order[0] == 0; -// if(tiles_->hmma(x) == HMMA_B_COL) -// pad_[b] = 16; -// if(tiles_->hmma(x) == HMMA_B_ROW) -// pad_[b] = 8; -// if(b_trans ^ b_col) -// pad_[b] = 4; -// else -// pad_[b] = 0; -// return previous_a != pad_[a] || previous_b != pad_[b]; -// } + if(auto* dot = dynamic_cast(x)) { + // a + ir::value *a = dot->get_operand(0);\ + size_t previous_a = pad_[a]; + if(tiles_->hmma(a) == HMMA_A_ROW) + pad_[a] = 16; + else if(tiles_->hmma(a) == HMMA_A_COL) + pad_[a] = 8; + else + pad_[a] = 0; + // b + ir::value *b = dot->get_operand(1); + size_t previous_b = pad_[b]; + if(tiles_->hmma(b) == HMMA_B_COL) + pad_[b] = 16; + if(tiles_->hmma(b) == HMMA_B_ROW) + pad_[b] = 8; + else + pad_[b] = 0; + return previous_a != pad_[a] || previous_b != pad_[b]; + } if(auto* cts = dynamic_cast(x)) { auto cts_order = tiles_->order(cts); ir::value *arg = cts->get_operand(0); auto arg_order = tiles_->order(arg); + size_t previous = pad_[cts]; if(cts_order != arg_order) pad_[cts] = 4; + return pad_[cts] != previous; } -// if(auto* tr = dynamic_cast(x)) { -// pad_[tr] = 4; -// } // padding for phi-nodes if(auto* phi = dynamic_cast(x)) { bool has_changed = false; @@ -142,7 +132,7 @@ bool liveness::do_pad(ir::value *x) { } return has_changed; } - // default -- no pading + // default -- no padding size_t previous = pad_[x]; pad_[x] = std::max(previous, 0); return pad_[x] != previous; diff --git a/lib/codegen/analysis/tiles.cc b/lib/codegen/analysis/tiles.cc index 13d3fbd13..77da5c03a 100644 --- a/lib/codegen/analysis/tiles.cc +++ b/lib/codegen/analysis/tiles.cc @@ -254,20 +254,20 @@ void tiles::run(ir::module &) { } order_[i] = order; } - for(size_t i = 0; i < num_groups; i++){ - std::vector dots; - for(ir::value* v: layout_->values(i)) - if(auto *x = dynamic_cast(v)) - dots.push_back(x); - for(ir::dot_inst* dot: dots){ - ir::value* a = dot->get_operand(0); - ir::value* b = dot->get_operand(1); - std::vector col = {0, 1}; - std::vector row = {1, 0}; - order_[layout_->id(a)] = is_trans(a) ? row : col; - order_[layout_->id(b)] = is_trans(b) ? col : row; - } - } +// for(size_t i = 0; i < num_groups; i++){ +// std::vector dots; +// for(ir::value* v: layout_->values(i)) +// if(auto *x = dynamic_cast(v)) +// dots.push_back(x); +// for(ir::dot_inst* dot: dots){ +// ir::value* a = dot->get_operand(0); +// ir::value* b = dot->get_operand(1); +// std::vector col = {0, 1}; +// std::vector row = {1, 0}; +// order_[layout_->id(a)] = is_trans(a) ? row : col; +// order_[layout_->id(b)] = is_trans(b) ? col : row; +// } +// } // tiling parameters for(auto x: largest_){ ir::value *i = x.second; diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index c355f9d2f..0b13b8982 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -1049,6 +1049,19 @@ void selection::lower_trans(ir::trans_inst *x, LLVMContext &ctx, Function *fn, I tmap_[x] = out; } +bool is_trans(ir::value *v) { + if(dynamic_cast(v)) { + return true; + } + if(auto *phi = dynamic_cast(v)) { + bool result = true; + for(ir::value *op: phi->ops()) + result = result && is_trans(op); + return result; + } + return false; +} + void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRBuilder<> &builder, distributed_tile *TC, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) { @@ -1082,8 +1095,11 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn auto ord_a = tiles_->order(dot->get_operand(0)); auto ord_b = tiles_->order(dot->get_operand(1)); - bool is_a_row = ord_a[ord_a.size() - 2] == 1; - bool is_b_row = ord_b[ord_b.size() - 2] == 1; + bool is_a_trans = is_trans(dot->get_operand(0)); + bool is_b_trans = is_trans(dot->get_operand(1)); + bool is_a_row = is_a_trans ^ (ord_a[ord_a.size() - 2] == 1); + bool is_b_row = is_b_trans ^ (ord_b[ord_b.size() - 2] == 1); + if(is_a_row){ offset_a_i = builder.CreateAdd(offset_a_i, builder.CreateURem(u_thread_id, builder.getInt32(4))); @@ -1124,7 +1140,7 @@ void selection::lower_hmma_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn Value *current_offset_a_i = builder.CreateAdd(offset_a_i, builder.getInt32(pack_i*stride_rep_i*pack_size_0_)); Value *current_offset_b_i = builder.CreateAdd(offset_b_j, builder.getInt32(pack_j*stride_rep_j*pack_size_1_)); indices_t idx_a = {current_offset_a_i, builder.CreateAdd(offset_a_k, _K)}; - indices_t idx_b = {current_offset_b_i, builder.CreateAdd(offset_b_k, _K)}; + indices_t idx_b = {builder.CreateAdd(offset_b_k, _K), current_offset_b_i}; idx_a.insert(idx_a.end(), x.first.begin(), x.first.end()); idx_b.insert(idx_b.end(), x.first.begin(), x.first.end()); Value *ha = TA->get_value(idx_a); diff --git a/lib/driver/module.cc b/lib/driver/module.cc index e300a75f2..f29c830f4 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, cu_module::cu_module(driver::context * context, std::unique_ptr ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ +// std::cout << source << std::endl; cu_context::context_switcher ctx(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 168e239e6..a276de4b1 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -32,7 +32,7 @@ int main() { for(const auto& c: configs){ std::tie(AT, BT, M, N, K) = c; std::cout << "// " << AT << " " << BT << " " << M << " " << N << " " << K << std::flush; - for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K)) + for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K)) std::cout << ", " << perf << std::flush; std::cout << std::endl; } diff --git a/tests/common/dot.h b/tests/common/dot.h index 599784570..bb27763b0 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -106,10 +106,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT, opt.num_warps = {nwarp}; } if(mode == BENCH) { - opt.defines.push_back({"TM", {"64", "128"}}); - opt.defines.push_back({"TN", {"64", "128"}}); - opt.defines.push_back({"TK", {"8"}}); - opt.num_warps = {2, 4, 8}; + opt.defines.push_back({"TM", {"128"}}); + opt.defines.push_back({"TN", {"128"}}); + opt.defines.push_back({"TK", {"16"}}); + opt.num_warps = {4}; } // kernels