#include <cstddef>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <tuple>
#include <vector>

#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "dot.h"
#include "util.h"

// Test driver for the `dot` kernel.
//
// Builds a list of configurations by sweeping the tile sizes (TM, TN, TK),
// the warp count and the two transposition flags (AT, BT), then calls
// `test_dot` once per configuration and prints " Pass! " or " Fail! ".
//
// Returns EXIT_SUCCESS when every configuration passes and EXIT_FAILURE
// otherwise, so that exit-code-based test harnesses can detect failures.
int main() {
  // initialize default compute device
  auto context = triton::driver::backend::contexts::get_default();
  // NOTE(review): `create` hands back a raw pointer; whether the caller owns
  // it (and should release it) is not visible from this file -- confirm
  // against triton/driver/stream.h.
  triton::driver::stream* stream = triton::driver::stream::create(context);

  // shapes to test
  // Field order: (dtype, AT, BT, M, N, K, TM, TN, TK, nwarp); it must match
  // the std::tie unpacking in the test loop below.
  using config_t =
      std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int>;
  std::vector<config_t> configs;
  for (int TM : {32, 64, 128})
    for (int TN : {32, 64, 128})
      for (int TK : {16})
        for (int nwarps : {4})
          for (bool AT : {false, true})
            for (bool BT : {false, true}) {
              // The problem shape (M, N, K) is set equal to the tile shape
              // (TM, TN, TK) for every configuration.
              configs.push_back(
                  config_t{HALF, AT, BT, TM, TN, TK, TM, TN, TK, nwarps});
            }

  // test
  bool all_passed = true;
  for (const auto& c : configs) {
    dtype_t dtype;
    bool AT, BT;
    int M, N, K, TM, TN, TK, nwarp;
    std::tie(dtype, AT, BT, M, N, K, TM, TN, TK, nwarp) = c;

    // Flush before running so the configuration name is visible even if the
    // test itself aborts.
    std::cout << "Testing " << c << " ... " << std::flush;
    const bool passed =
        test_dot(stream, dtype, AT, BT, M, N, K, {0, 1}, {0, 1}, TM, TN, TK,
                 static_cast<size_t>(nwarp));
    if (passed) {
      std::cout << " Pass! " << std::endl;
    } else {
      std::cout << " Fail! " << std::endl;
      all_passed = false;
    }
  }

  return all_passed ? EXIT_SUCCESS : EXIT_FAILURE;
}