diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index 6d4c728a8..0a64729e9 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -26,8 +26,8 @@ struct perf_t { perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ - typedef half NumericT; - std::string ty = "half"; + typedef float NumericT; + std::string ty = "float"; size_t dt_nbytes = sizeof(NumericT); triton::driver::context* context = stream->context(); std::vector hc(M*N); @@ -112,7 +112,11 @@ int main() { std::vector configs = { // {false, false, 8192, 512, 512}, // {false, true, 8192, 8192, 8192} - {false, true, 128, 128, 128}, +// {false, true, 128, 128, 128}, +// {false, false, 128, 128, 128}, +// {true, false, 128, 128, 128}, + {true, true, 128, 128, 128} + // {false, true, 32768, 256, 512} // {true, false, 8192, 512, 512}, // {true, true, 8192, 512, 512} diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp index 8c482b0b6..394084395 100644 --- a/lib/dnn/base.cpp +++ b/lib/dnn/base.cpp @@ -62,11 +62,11 @@ std::pair base::get_profile_impl(driver::stream *stream, std::v jit->add_module(name_.c_str(), src.c_str(), best.params); } else{ - params_t params = heuristics(); +// params_t params = heuristics(); // params_t params = jit->get_valid(name_.c_str(), src.c_str()); // params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT // params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN -// params_t params = {4, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 1, 32, 16, 4, 4, 4, 4, 4, 1}; // TT + params_t params = {4, 16, 4, 2, 16, 4, 8, 2, 2, 8, 2, 32, 8, 1}; // TT jit->add_module(name_.c_str(), src.c_str(), params); } triton::driver::kernel* kernel = jit->get_function(name_.c_str());