diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp index 9d07bf227..554b3bcc3 100644 --- a/include/triton/tools/bench.hpp +++ b/include/triton/tools/bench.hpp @@ -38,7 +38,7 @@ inline double bench(std::function const & op, driver::stream * stream) double total_time = 0; op(); stream->synchronize(); - while(total_time*1e-9 < 1e-1){ + while(total_time*1e-9 < 1e-3){ float norm = 1; // normalize clock if possible to reduce noise in auto-tuning if(auto cu_device = dynamic_cast(stream->context()->device())) diff --git a/lib/codegen/analysis/allocation.cc b/lib/codegen/analysis/allocation.cc index 8ff77eb25..fc2a5ce22 100644 --- a/lib/codegen/analysis/allocation.cc +++ b/lib/codegen/analysis/allocation.cc @@ -20,14 +20,16 @@ unsigned allocation::is_ld_padded(ir::value *x) { if(trans->get_perm()[0]->get_value() != 0) return 4; } + auto order = tiles_->order(x); + bool is_col_major = order[0] == 0; if(tiles_->hmma(x) == HMMA_A_ROW) - return 8; + return is_col_major ? 16 : 8; if(tiles_->hmma(x) == HMMA_A_COL) - return 16; + return is_col_major ? 8 : 16; if(tiles_->hmma(x) == HMMA_B_COL) - return 8; + return is_col_major ? 16 : 8; if(tiles_->hmma(x) == HMMA_B_ROW) - return 16; + return is_col_major ? 8 : 16; if(auto* phi = dynamic_cast(x)) { unsigned result = 0; for(unsigned i = 0; i < phi->get_num_incoming(); i++) diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc index c3a7c8b00..bb75df10e 100644 --- a/tests/unit/dot.cc +++ b/tests/unit/dot.cc @@ -31,7 +31,7 @@ static void cpu_ref(std::vector &c, const std::vector &a, const std::vecto for(size_t n = 0; n < N; n++){ float acc = 0; for(size_t k = 0; k < K; k++) - acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]); + acc = acc + (AT ? a[k*M + m] : a[m*K + k]) * (BT ? b[n*K + k] : b[k*N + n]); c[m + n*M] = static_cast(acc); } } @@ -49,25 +49,47 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K, cpu_ref(c, a, b, M, N, K); } +template +struct to_string; -bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, size_t nwarp){ - typedef float NumericT; - std::string ty = "float"; - size_t dt_nbytes = sizeof(NumericT); +template<> struct to_string{ + static constexpr const char* value = "half"; +}; + +template<> struct to_string{ + static constexpr const char* value = "float"; +}; + +template<> struct to_string{ + static constexpr const char* value = "double"; +}; + +enum dtype_t { + FLOAT, + HALF, + DOUBLE +}; + +template +bool do_test(drv::stream* stream, bool AT, bool BT, + int32_t M, int32_t N, int32_t K, + int32_t TM, int32_t TN, int32_t TK, size_t nwarp){ + std::string ty = to_string::value; + size_t dt_nbytes = sizeof(T); drv::context* context = stream->context(); - std::vector hc(M*N); - std::vector ha(M*K); - std::vector hb(K*N); + std::vector hc(M*N); + std::vector ha(M*K); + std::vector hb(K*N); int32_t lda = AT ? K : M; int32_t ldb = BT ? N : K; int32_t ldc = M; srand(0); for(size_t i = 0; i < ha.size(); i++) - ha[i] = static_cast((float)rand()/RAND_MAX); + ha[i] = static_cast((float)rand()/RAND_MAX); for(size_t i = 0; i < hb.size(); i++) - hb[i] = static_cast((float)rand()/RAND_MAX); + hb[i] = static_cast((float)rand()/RAND_MAX); for(size_t i = 0; i < hc.size(); i++) - hc[i] = static_cast((double)0); + hc[i] = static_cast((double)0); auto dc = std::shared_ptr(drv::buffer::create(context, hc.size()*dt_nbytes)); auto da = std::shared_ptr(drv::buffer::create(context, ha.size()*dt_nbytes)); auto db = std::shared_ptr(drv::buffer::create(context, hb.size()*dt_nbytes)); @@ -92,33 +114,47 @@ bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_ } // test stream->read(&*dc, true, 0, hc); - std::vector rc(hc.size()); + std::vector rc(hc.size()); cpu_ref(AT, BT, M, N, K, rc, ha, hb); return testing::diff(hc, rc); } +bool do_test(triton::driver::stream *stream, + dtype_t dtype, bool AT, bool BT, + int32_t M, int32_t N, int32_t K, + int32_t TM, int32_t TN, int32_t TK, size_t nwarp) { + switch(dtype){ + case HALF: return do_test(stream, AT, BT, M, N, K, TM, TN, TK, nwarp); + case FLOAT: return do_test(stream, AT, BT, M, N, K, TM, TN, TK, nwarp); + case DOUBLE: return do_test(stream, AT, BT, M, N, K, TM, TN, TK, nwarp); + default: break; + } + return false; +} + int main() { // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); triton::driver::stream* stream = triton::driver::stream::create(context); // shapes to benchmark - typedef std::tuple config_t; + typedef std::tuple config_t; std::vector configs; for(bool AT: std::array{false, true}) for(bool BT: std::array{false, true}) - for(int TM: std::vector{16, 128}) - for(int TN: std::vector{16, 128}) + for(int TM: std::vector{32, 64}) + for(int TN: std::vector{32, 64}) for(int TK: std::vector{16, 32}) for(int nwarps: std::vector{1, 2, 4, 8}){ - configs.push_back(config_t{AT, BT, 128, 128, 128, TM, TN, TK, nwarps}); + configs.push_back(config_t{HALF, AT, BT, 128, 128, 128, TM, TN, TK, nwarps}); } // does the work + dtype_t dtype; bool AT, BT; int M, N, K, TM, TN, TK, nwarp; for(const auto& c: configs){ - std::tie(AT, BT, M, N, K, TM, TN, TK, nwarp) = c; + std::tie(dtype, AT, BT, M, N, K, TM, TN, TK, nwarp) = c; std::cout << "Testing " << c << " ... " << std::flush; - if(do_test(stream, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp)) + if(do_test(stream, dtype, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp)) std::cout << " Pass! " << std::endl; else{ std::cout << " Fail! " << std::endl;