diff --git a/include/triton/codegen/analysis/tiles.h b/include/triton/codegen/analysis/tiles.h
index 6fe964738..ca1eb0e90 100644
--- a/include/triton/codegen/analysis/tiles.h
+++ b/include/triton/codegen/analysis/tiles.h
@@ -27,11 +27,7 @@ class align;
 
 enum layout_t {
   SCANLINE,
-  HMMA_C,
-  HMMA_A_COL,
-  HMMA_A_ROW,
-  HMMA_B_COL,
-  HMMA_B_ROW
+  HMMA_C
 };
 
 class tiles {
diff --git a/lib/codegen/analysis/liveness.cc b/lib/codegen/analysis/liveness.cc
index e23bb96da..297f31d92 100644
--- a/lib/codegen/analysis/liveness.cc
+++ b/lib/codegen/analysis/liveness.cc
@@ -89,28 +89,43 @@ void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &
   }
 }
 
+bool is_trans(ir::value *v) {
+  if(dynamic_cast<ir::trans_inst*>(v)) {
+    return true;
+  }
+  if(auto *phi = dynamic_cast<ir::phi_node*>(v)) {
+    bool result = true;
+    for(ir::value *op: phi->ops())
+      result = result && is_trans(op);
+    return result;
+  }
+  return false;
+}
+
+
 bool liveness::do_pad(ir::value *x) {
   // alignment for matrix product
   if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
     // a
-    ir::value *a = dot->get_operand(0);\
-    size_t previous_a = pad_[a];
-    if(tiles_->hmma(a) == HMMA_A_ROW)
-      pad_[a] = 16;
-    else if(tiles_->hmma(a) == HMMA_A_COL)
-      pad_[a] = 8;
-    else
-      pad_[a] = 0;
-    // b
+    ir::value *a = dot->get_operand(0);
     ir::value *b = dot->get_operand(1);
-    size_t previous_b = pad_[b];
-    if(tiles_->hmma(b) == HMMA_B_COL)
-      pad_[b] = 16;
-    if(tiles_->hmma(b) == HMMA_B_ROW)
-      pad_[b] = 8;
-    else
-      pad_[b] = 0;
-    return previous_a != pad_[a] || previous_b != pad_[b];
+    size_t a_previous = pad_[a];
+    size_t b_previous = pad_[b];
+    auto a_order = tiles_->order(a);
+    auto b_order = tiles_->order(b);
+    bool a_row = is_trans(a) ^ (a_order[0] == 1);
+    bool b_row = is_trans(b) ^ (b_order[0] == 1);
+    auto a_shapes = a->get_type()->get_tile_shapes();
+    auto b_shapes = b->get_type()->get_tile_shapes();
+    pad_[a] = std::max(pad_[a], (24 - a_shapes[a_row ? 0 : 1]) % 32);
+    pad_[b] = std::max(pad_[b], (24 - b_shapes[b_row ? 1 : 0]) % 32);
+    return a_previous != pad_[a] || b_previous != pad_[b];
+  }
+  if(auto* trans = dynamic_cast<ir::trans_inst*>(x)) {
+    ir::value *op = trans->get_operand(0);
+    size_t previous = pad_[op];
+    pad_[op] = std::max(pad_[op], pad_[x]);
+    return previous != pad_[op];
   }
   if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
     auto cts_order = tiles_->order(cts);
@@ -118,7 +133,7 @@ bool liveness::do_pad(ir::value *x) {
     auto arg_order = tiles_->order(arg);
     size_t previous = pad_[cts];
     if(cts_order != arg_order)
-      pad_[cts] = 4;
+      pad_[cts] = std::max(pad_[cts], 4);
     return pad_[cts] != previous;
   }
   // padding for phi-nodes
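
The new padding rule above boils down to a single integer expression. Below is a minimal standalone sketch (not part of the patch) that reproduces just that expression with plain `int` arithmetic; the real pass applies it to the tile shape selected by `a_row`/`b_row` and to whatever integer type `pad_` actually uses.

```cpp
#include <algorithm>
#include <cstdio>

// Mirrors the expression from liveness::do_pad:
//   pad = max(previous_pad, (24 - shape) % 32)
// where `shape` is the tile extent picked by the a_row / b_row flag.
int dot_operand_pad(int previous_pad, int shape) {
  return std::max(previous_pad, (24 - shape) % 32);
}

int main() {
  for (int shape : {8, 16, 24, 32, 64})
    std::printf("shape %3d -> pad %d\n", shape, dot_operand_pad(0, shape));
  // With int arithmetic the remainder is negative once shape > 24,
  // e.g. (24 - 32) % 32 == -8, so std::max() keeps the previous pad there.
  return 0;
}
```
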
diff --git a/lib/codegen/analysis/tiles.cc b/lib/codegen/analysis/tiles.cc
index 77da5c03a..3d414f723 100644
--- a/lib/codegen/analysis/tiles.cc
+++ b/lib/codegen/analysis/tiles.cc
@@ -215,15 +215,7 @@ void tiles::run(ir::module &) {
   for(size_t i = 0; i < num_groups; i++) {
     const auto& values = layout_->values(i);
     bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
-    bool hmma_a_col = std::any_of(values.begin(), values.end(), &is_hmma_a_col);
-    bool hmma_a_row = std::any_of(values.begin(), values.end(), &is_hmma_a_row);
-    bool hmma_b_col = std::any_of(values.begin(), values.end(), &is_hmma_b_col);
-    bool hmma_b_row = std::any_of(values.begin(), values.end(), &is_hmma_b_row);
     if(hmma_c) hmma_[i] = HMMA_C;
-    else if(hmma_a_col) hmma_[i] = HMMA_A_COL;
-    else if(hmma_a_row) hmma_[i] = HMMA_A_ROW;
-    else if(hmma_b_col) hmma_[i] = HMMA_B_COL;
-    else if(hmma_b_row) hmma_[i] = HMMA_B_ROW;
     else hmma_[i] = SCANLINE;
   }
 
@@ -254,20 +246,33 @@ void tiles::run(ir::module &) {
     }
     order_[i] = order;
   }
-//  for(size_t i = 0; i < num_groups; i++){
-//    std::vector<ir::dot_inst*> dots;
-//    for(ir::value* v: layout_->values(i))
-//      if(auto *x = dynamic_cast<ir::dot_inst*>(v))
-//        dots.push_back(x);
-//    for(ir::dot_inst* dot: dots){
-//      ir::value* a = dot->get_operand(0);
-//      ir::value* b = dot->get_operand(1);
-//      std::vector<int> col = {0, 1};
-//      std::vector<int> row = {1, 0};
-//      order_[layout_->id(a)] = is_trans(a) ? row : col;
-//      order_[layout_->id(b)] = is_trans(b) ? col : row;
-//    }
-//  }
+  // matrix multiplication optimizations
+  for(size_t i = 0; i < num_groups; i++){
+    std::vector<ir::dot_inst*> dots;
+    for(ir::value* v: layout_->values(i))
+      if(auto *x = dynamic_cast<ir::dot_inst*>(v))
+        dots.push_back(x);
+    for(ir::dot_inst* dot: dots){
+      ir::value* a = dot->get_operand(0);
+      ir::value* b = dot->get_operand(1);
+      if(hmma_.at(layout_->id(dot)) == HMMA_C){
+        auto a_val = layout_->values(layout_->id(a));
+        auto b_val = layout_->values(layout_->id(b));
+        for(ir::value *v: a_val)
+          if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
+            order_[layout_->id(a)] = order_[layout_->id(cts->get_operand(0))];
+        for(ir::value *v: b_val)
+          if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
+            order_[layout_->id(b)] = order_[layout_->id(cts->get_operand(0))];
+      }
+      else{
+        std::vector<int> col = {0, 1};
+        std::vector<int> row = {1, 0};
+        order_[layout_->id(a)] = is_trans(a) ? row : col;
+        order_[layout_->id(b)] = is_trans(b) ? col : row;
+      }
+    }
+  }
   // tiling parameters
   for(auto x: largest_){
     ir::value *i = x.second;
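
For readers unfamiliar with the order_ maps, here is a small self-contained sketch (illustrative only, not the pass itself) of the choice made in the new loop for dot operands whose group is not HMMA_C: a transposed operand gets {1, 0}, a non-transposed one {0, 1}. In the HMMA_C case the patch instead copies the order of the value feeding the operand's copy_to_shared.

```cpp
#include <iostream>
#include <vector>

// Same selection as the else-branch added to tiles::run above.
static std::vector<int> pick_order(bool operand_is_trans) {
  std::vector<int> col = {0, 1};
  std::vector<int> row = {1, 0};
  return operand_is_trans ? row : col;
}

int main() {
  for (bool t : {false, true}) {
    std::vector<int> ord = pick_order(t);
    std::cout << "is_trans=" << t << " -> {" << ord[0] << ", " << ord[1] << "}\n";
  }
  return 0;
}
```
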
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index e9f5f8921..19c55a0a1 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -239,7 +239,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   axes.run(module);
   layouts.run(module);
   align.run(module);
+//  ir::print(module, std::cout);
   tiles.run(module);
+//  ir::print(module, std::cout);
   selection.run(module, *llvm);
   // return binary
   std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc
index a276de4b1..927f0044b 100644
--- a/tests/bench/dot.cc
+++ b/tests/bench/dot.cc
@@ -7,32 +7,34 @@ int main() {
   auto context = triton::driver::backend::contexts::get_default();
   triton::driver::stream* stream = triton::driver::stream::create(context);
   // shapes to benchmark
-  typedef std::tuple<bool, bool, int, int, int> config_t;
+  typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
   std::vector<config_t> configs;
+  for(auto ord: std::vector<std::vector<int>>{{0, 1}, {1, 0}})
   for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true}, {true, false}, {true, true}}){
     std::vector<config_t> tmp = {
-      config_t{x[0], x[1], 2048, 2048, 2048},
-//      config_t{x[0], x[1], 16, 2048, 2048},
-//      config_t{x[0], x[1], 32, 2048, 2048},
-//      config_t{x[0], x[1], 64, 2048, 2048},
-//      config_t{x[0], x[1], 128, 2048, 2048},
-//      config_t{x[0], x[1], 7000, 2048, 2048},
-//      config_t{x[0], x[1], 16, 4096, 4096},
-//      config_t{x[0], x[1], 32, 4096, 4096},
-//      config_t{x[0], x[1], 64, 4096, 4096},
-//      config_t{x[0], x[1], 128, 4096, 4096},
-//      config_t{x[0], x[1], 7000, 4096, 4096}
+      config_t{ord, x[0], x[1], 2048, 2048, 2048},
+//      config_t{ord, x[0], x[1], 16, 2048, 2048},
+//      config_t{ord, x[0], x[1], 32, 2048, 2048},
+//      config_t{ord, x[0], x[1], 64, 2048, 2048},
+//      config_t{ord, x[0], x[1], 128, 2048, 2048},
+//      config_t{ord, x[0], x[1], 7000, 2048, 2048},
+//      config_t{ord, x[0], x[1], 16, 4096, 4096},
+//      config_t{ord, x[0], x[1], 32, 4096, 4096},
+//      config_t{ord, x[0], x[1], 64, 4096, 4096},
+//      config_t{ord, x[0], x[1], 128, 4096, 4096},
+//      config_t{ord, x[0], x[1], 7000, 4096, 4096}
     };
     configs.insert(configs.end(), tmp.begin(), tmp.end());
   }
 
   // does the work
+  std::vector<int> ord;
   bool AT, BT;
   int32_t M, N, K;
   for(const auto& c: configs){
-    std::tie(AT, BT, M, N, K) = c;
-    std::cout << "// " << AT << " " << BT << " " << M << " " << N << " " << K << std::flush;
-    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K))
+    std::tie(ord, AT, BT, M, N, K) = c;
+    std::cout << "// " << c << std::flush;
+    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }
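
The benchmark now prepends the storage order to each configuration tuple. A minimal sketch of how such a tuple is built and unpacked follows; the names and shapes here are illustrative, not the benchmark itself.

```cpp
#include <iostream>
#include <tuple>
#include <vector>

typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;

int main() {
  std::vector<config_t> configs;
  for (auto ord : std::vector<std::vector<int>>{{0, 1}, {1, 0}})
    configs.push_back(config_t{ord, false, true, 2048, 2048, 2048});

  std::vector<int> ord;
  bool AT, BT;
  int M, N, K;
  for (const auto& c : configs) {
    std::tie(ord, AT, BT, M, N, K) = c;   // the leading element is the order
    std::cout << "{" << ord[0] << "," << ord[1] << "} " << AT << " " << BT
              << " " << M << " " << N << " " << K << "\n";
  }
  return 0;
}
```
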
diff --git a/tests/common/dot.h b/tests/common/dot.h
index bb27763b0..00d605f5d 100644
--- a/tests/common/dot.h
+++ b/tests/common/dot.h
@@ -19,7 +19,7 @@ static void cc_dot(std::vector<T> &c, const std::vector<T> &a, const std::vector
   for(size_t n = 0; n < N; n++){
     float acc = 0;
     for(size_t k = 0; k < K; k++)
-      acc = acc + (AT ? a[k*M + m] : a[m*K + k]) * (BT ? b[n*K + k] : b[k*N + n]);
+      acc = acc + (!AT ? a[k*M + m] : a[m*K + k]) * (!BT ? b[n*K + k] : b[k*N + n]);
     c[m + n*M] = static_cast<T>(acc);
   }
 }
@@ -67,6 +67,7 @@ template<class T>
 bool triton_dot(drv::stream* stream, bool AT, bool BT,
                 int32_t M, int32_t N, int32_t K,
                 int32_t TM, int32_t TN, int32_t TK, size_t nwarp,
+                const std::vector<int>& a_order, const std::vector<int>& b_order,
                 run_mode_t mode, std::vector<double>& bench, bool &test){
   std::string ty = to_string<T>::value;
   size_t dt_nbytes = sizeof(T);
@@ -74,6 +75,8 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
   int32_t lda = AT ? K : M;
   int32_t ldb = BT ? N : K;
   int32_t ldc = M;
+  std::vector<std::string> sa = { "1", "lda" };
+  std::vector<std::string> sb = { "1", "ldb" };
 
   // inputs
   auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
@@ -82,20 +85,20 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
 
   // macros
   rt::function::options_space_t opt;
-  // B access patterns
-  opt.defines.push_back({"USEB", {BT? "^b" : "b" }});
-  opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }});
-  opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }});
-  opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }});
-  opt.defines.push_back({"STRIDE_BK", {BT? "1" : "ldb" }});
-  opt.defines.push_back({"STRIDE_BN", {BT? "ldb" : "1" }});
   // A access patterns
-  opt.defines.push_back({"USEA", {AT? "^a" : "a" }});
-  opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }});
-  opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }});
-  opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }});
-  opt.defines.push_back({"STRIDE_AK", {AT? "lda" : "1" }});
-  opt.defines.push_back({"STRIDE_AM", {AT? "1" : "lda" }});
+  opt.defines.push_back({"USEA", {AT? "^a" : "a" }});
+  opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }});
+  opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }});
+  opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }});
+  opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
+  opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
+  // B access patterns
+  opt.defines.push_back({"USEB", {BT? "^b" : "b" }});
+  opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }});
+  opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }});
+  opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }});
+  opt.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
+  opt.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
   // data-type
   opt.defines.push_back({"TYPE", {ty}});
   // tile sizes
@@ -164,13 +167,14 @@ std::vector<double> bench_dot(drv::stream* stream,
                               dtype_t dtype, bool AT, bool BT,
-                              int32_t M, int32_t N, int32_t K) {
+                              int32_t M, int32_t N, int32_t K,
+                              const std::vector<int>& a_order, const std::vector<int>& b_order) {
   std::vector<double> bench;
   bool test;
   switch(dtype){
-    case HALF:   triton_dot<half_float::half>(stream, AT, BT, M, N, K, 0, 0, 0, 0, BENCH, bench, test); break;
-    case FLOAT:  triton_dot<float>(stream, AT, BT, M, N, K, 0, 0, 0, 0, BENCH, bench, test); break;
-    case DOUBLE: triton_dot<double>(stream, AT, BT, M, N, K, 0, 0, 0, 0, BENCH, bench, test); break;
+    case HALF:   triton_dot<half_float::half>(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
+    case FLOAT:  triton_dot<float>(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
+    case DOUBLE: triton_dot<double>(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
     default: break;
   }
   return bench;
@@ -178,13 +182,14 @@ std::vector<double> bench_dot(drv::stream* stream,
 
 bool test_dot(drv::stream* stream, dtype_t dtype, bool AT, bool BT,
               int32_t M, int32_t N, int32_t K,
+              const std::vector<int>& a_order, const std::vector<int>& b_order,
               int32_t TM, int32_t TN, int32_t TK, size_t nwarp) {
   std::vector<double> bench;
   bool test = false;
   switch(dtype){
-    case HALF:   triton_dot<half_float::half>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, TEST, bench, test); break;
-    case FLOAT:  triton_dot<float>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, TEST, bench, test); break;
-    case DOUBLE: triton_dot<double>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, TEST, bench, test); break;
+    case HALF:   triton_dot<half_float::half>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
+    case FLOAT:  triton_dot<float>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
+    case DOUBLE: triton_dot<double>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
    default: break;
   }
   return test;
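
The STRIDE_* macros are now picked from the order vector rather than hard-coded. Below is a standalone sketch of that selection, assuming only the sa = { "1", "lda" } convention used above (sa[0] is the unit stride, sa[1] the leading dimension); it is not part of the patch.

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> sa = { "1", "lda" };
  for (auto a_order : std::vector<std::vector<int>>{{0, 1}, {1, 0}})
    for (bool AT : {false, true}) {
      // Same ternaries as the STRIDE_AK / STRIDE_AM defines in triton_dot.
      std::string stride_ak = AT ? sa[a_order[0]] : sa[a_order[1]];
      std::string stride_am = AT ? sa[a_order[1]] : sa[a_order[0]];
      std::cout << "order={" << a_order[0] << "," << a_order[1] << "} AT=" << AT
                << " STRIDE_AK=" << stride_ak << " STRIDE_AM=" << stride_am << "\n";
    }
  return 0;
}
```
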
diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc
index 53fbc990d..59b556858 100644
--- a/tests/unit/dot.cc
+++ b/tests/unit/dot.cc
@@ -25,7 +25,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(dtype, AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
     std::cout << "Testing " << c << " ... " << std::flush;
-    if(test_dot(stream, dtype, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
+    if(test_dot(stream, dtype, AT, BT, M, N, K, {0, 1}, {0, 1}, TM, TN, TK, (size_t)nwarp))
       std::cout << " Pass! " << std::endl;
     else{
       std::cout << " Fail! " << std::endl;
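
The unit test passes the new order arguments as bare {0, 1} lists. A tiny self-contained example (hypothetical function name) of why that works: a braced-init-list binds to a const std::vector<int>& parameter by materializing a temporary vector.

```cpp
#include <iostream>
#include <vector>

static void takes_order(const std::vector<int>& order) {
  std::cout << order[0] << ", " << order[1] << "\n";
}

int main() {
  takes_order({0, 1});   // same form as the {0, 1} arguments added to test_dot
  takes_order({1, 0});
  return 0;
}
```
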