diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index 87bb739e2..3f04d01ad 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -6,6 +6,7 @@ #include "triton/driver/stream.h" #include "triton/dnn/dot.h" #include "triton/tools/bench.hpp" +#include "triton/external/half.hpp" #include "cuda.h" template @@ -25,7 +26,7 @@ struct perf_t { perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ - typedef float NumericT; + typedef half NumericT; std::string ty = "half"; size_t dt_nbytes = sizeof(NumericT); triton::driver::context* context = stream->context(); @@ -34,11 +35,11 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int std::vector<NumericT> hb(K*N); srand(0); for(size_t i = 0; i < ha.size(); i++) - ha[i] = (NumericT)rand()/RAND_MAX; + ha[i] = static_cast<NumericT>((double)rand()/RAND_MAX); for(size_t i = 0; i < hb.size(); i++) - hb[i] = (NumericT)rand()/RAND_MAX; + hb[i] = static_cast<NumericT>((double)rand()/RAND_MAX); for(size_t i = 0; i < hc.size(); i++) - hc[i] = 0; + hc[i] = static_cast<NumericT>((double)0); triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes); triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes); triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes); @@ -48,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int stream->synchronize(); triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8); // benchmark triton - double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream); + double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream); // benchmark cublas // NumericT alpha = 1; // NumericT beta = 0; @@ -73,10 +74,10 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int // test stream->read(dc, true, 0, 
hc); - std::vector rc(hc.size()); + std::vector rc(hc.size()); dot.cpu_ref(rc, ha, hb); for(size_t i = 0; i < M*N; i++) - if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ + if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; exit(EXIT_FAILURE); } @@ -111,7 +112,7 @@ int main() { std::vector configs = { // {false, false, 8192, 512, 512}, // {false, true, 8192, 8192, 8192} - {true, true, 128, 128, 128}, + {false, true, 128, 128, 128}, // {false, true, 32768, 256, 512} // {true, false, 8192, 512, 512}, // {true, true, 8192, 512, 512} diff --git a/include/triton/dnn/dot.h b/include/triton/dnn/dot.h index c655d12b5..2beeede7b 100644 --- a/include/triton/dnn/dot.h +++ b/include/triton/dnn/dot.h @@ -42,9 +42,9 @@ public: size_t M, size_t N, size_t K){ for(size_t m = 0; m < M; m++) for(size_t n = 0; n < N; n++){ - T acc = 0; + T acc = static_cast<T>((double)0); for(size_t k = 0; k < K; k++) - acc += (AT?a[k + m*K]:a[m + k*M]) * (BT?b[n + k*N]:b[k + n*K]); + acc = acc + (AT?a[k + m*K]:a[m + k*M]) * (BT?b[n + k*N]:b[k + n*K]); c[m + n*M] = acc; } } diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index fffec7794..ae227b135 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -73,7 +73,6 @@ public: optimize_dot.run(module); optimize_trans.run(module); optimize_dce.run(module); -// ir::print(module, std::cout); } void target_dependent(ir::module &module) { diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index bc4c7118d..b05f7e79e 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -194,7 +194,6 @@ std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) { for(ir::instruction *i : block->get_inst_list()) for(auto &x: params_[i]) if(seen.insert(x.second).second && !x.second->has_value()){ -// std::cout << i->get_name() << " " << x.first << std::endl; result.push_back(x.second); } @@ -291,28 +290,29 @@ 
void tune::run(ir::module &mod) { } // initialize grids + +// for(ir::instruction *i: grids_){ +// auto shapes = i->get_type()->get_tile_shapes(); +// for(size_t k = 0; k < shapes.size(); k++) +// if(shapes[k]->get_value() == 1) { +// if(fragments_.at({i, k}) == STRIDED_SCAN){ +// params_.at(i).at("nts.d" + std::to_string(k))->set_value(1); +// params_.at(i).at("mts.d" + std::to_string(k))->set_value(1); +// } +// if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){ +// params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1); +// params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1); +// } +// } +// } +} + +void tune::init(ir::module &mod) { for(ir::function *fn: mod.get_function_list()){ std::map references; create_grids(grids_, references, fn); } - for(ir::instruction *i: grids_){ - auto shapes = i->get_type()->get_tile_shapes(); - for(size_t k = 0; k < shapes.size(); k++) - if(shapes[k]->get_value() == 1) { - if(fragments_.at({i, k}) == STRIDED_SCAN){ - params_.at(i).at("nts.d" + std::to_string(k))->set_value(1); - params_.at(i).at("mts.d" + std::to_string(k))->set_value(1); - } - if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){ - params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1); - params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1); - } - } - } -} - -void tune::init(ir::module &mod) { num_threads_ = get_req_num_threads(grids_.front()); } @@ -407,7 +407,9 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er else { ir::metaparameter *fpw = params_[i]["fpw.d" + strk]; ir::metaparameter *wpt = params_[i]["wpt.d" + strk]; - multiple = fpw->get_value()*wpt->get_value()*8; + multiple = fpw->get_value()*wpt->get_value(); + if(k < 2) + multiple *= 8; } if(shapes[k]->get_value() % multiple != 0) errors[i].push_back("for dim " + strk + ": shape (" + to_string(shapes[k]->get_value()) + ")" diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp index 1ad741240..8c482b0b6 100644 --- a/lib/dnn/base.cpp +++ b/lib/dnn/base.cpp @@ -62,11 +62,11 @@ std::pair 
base::get_profile_impl(driver::stream *stream, std::v jit->add_module(name_.c_str(), src.c_str(), best.params); } else{ -// params_t params = heuristics(); + params_t params = heuristics(); // params_t params = jit->get_valid(name_.c_str(), src.c_str()); // params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT // params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN - params_t params = {4, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 1, 32, 16, 4, 4, 4, 4, 4, 1}; // TT +// params_t params = {4, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 1, 32, 16, 4, 4, 4, 4, 4, 1}; // TT jit->add_module(name_.c_str(), src.c_str(), params); } triton::driver::kernel* kernel = jit->get_function(name_.c_str()); diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp index 83798921a..65395695c 100644 --- a/lib/dnn/dot.cpp +++ b/lib/dnn/dot.cpp @@ -74,8 +74,8 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel, void dot::triton_c_src(std::ostream &os) const { std::string AS0 = "TM", AS1 = "TK"; std::string BS0 = "TK", BS1 = "TN"; - std::string XAS0 = "TM", XAS1 = "TK/1", XAS2 = "1"; - std::string XBS0 = "TK/1", XBS1 = "1", XBS2 = "TN"; + std::string XAS0 = "TM", XAS1 = "TK", XAS2 = "1"; + std::string XBS0 = "TK", XBS1 = "1", XBS2 = "TN"; std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]"; std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; std::string lda0 = "*lda", lda1 = ""; @@ -105,11 +105,12 @@ void dot::triton_c_src(std::ostream &os) const { std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")"; std::string res = R"( -const tunable int TM = {32}; -const tunable int TN = {32}; +const tunable int TM = {16, 32, 64, 128}; +const tunable int TN = {16, 32, 64, 128}; const tunable int TK = {32}; const tunable int GZ = {1}; + void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, restrict read_only align(16) )" + b_ty_ + R"( *B, restrict read_only align(16) float *C, diff --git 
a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp index 1f6a60ccd..86102a460 100644 --- a/lib/runtime/jit.cpp +++ b/lib/runtime/jit.cpp @@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges, size_t D = ranges.size(); std::vector<size_t> values(D, 0); // thread pools - ThreadPool pool(nthreads); +// ThreadPool pool(nthreads); // Start with innermost loop size_t i = D - 1; while(true){ // Execute function - pool.enqueue(f,values); -// f(values); +// pool.enqueue(f,values); + f(values); while(values[i]++ == ranges[i] - 1){ if(i == 0) return;