#pragma once #ifndef _TRITON_TESTS_UTIL_H #define _TRITON_TESTS_UTIL_H #include #include #include "triton/runtime/function.h" namespace drv = triton::driver; namespace rt = triton::runtime; /* ------------------------ * Launch Grid * ------------------------ */ inline size_t ceil(size_t x, size_t y) { return (x + y - 1) / y; } inline rt::function::grid_fn_ty grid1d(size_t N) { return [N](const rt::function::options_t& x) { return rt::grid_t{ceil(N, x.D("TN"))}; }; } inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) { return [M, N](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D("TM")), ceil(N, x.D("TN"))}; }; } inline rt::function::grid_fn_ty grid_nd(const std::vector &shape, const std::vector& ts) { return [&shape, &ts](const rt::function::options_t& x) { rt::grid_t ret; for(size_t d = 0; d < shape.size(); d++) ret.push_back(ceil(shape[d], x.D(ts[d]))); return ret; }; } inline std::vector> tile_nd(size_t rank) { assert(rank <= 3); if(rank == 1) return {{"128", "256", "512", "1024"}}; if(rank == 2) return {{"16", "32", "64"}, {"16", "32", "64"}}; if(rank == 3) return {{"4", "16", "32"}, {"4", "16", "32"}, {"4", "16", "32"}}; return {}; } /* ------------------------ * Tensor Initialization * ------------------------ */ template void init_rand(std::vector& x) { for(size_t i = 0; i < x.size(); i++) x[i] = i; } template void init_zeros(std::vector& x) { for(size_t i = 0; i < x.size(); i++) x[i] = 0; } /* ------------------------ * Loop Nests * ------------------------ */ void _loop_nest(std::vector const & ranges, std::function const &)> const & f){ int D = ranges.size(); std::vector values(D, 0); // Start with innermost loop int i = D - 1; while(true){ // Execute function f(values); while(values[i]++ == ranges[i] - 1){ if(i == 0) return; values[i--] = 0; } i = D - 1; } } /* ----------------------- * TENSOR INDEXING * ----------------------- */ enum order_t { ROWMAJOR, COLMAJOR }; int offset(const std::vector& idx, const std::vector& shapes) { int result = idx[0]; int ld = 1; for(int i = 1; i < idx.size(); i++){ ld *= shapes[i - 1]; result += idx[i]*ld; } return result; } /* ----------------------- * REDUCTION HELPERS * ----------------------- */ enum reduce_op_t { ADD, MAX, MIN }; std::string to_str(reduce_op_t op) { switch (op) { case ADD: return "+"; case MAX: return "max"; case MIN: return "min"; default: break; } assert(false); return ""; } template std::function get_accumulator(reduce_op_t op) { switch (op) { case ADD: return [](T x, T y) { return x + y; }; case MAX: return [](T x, T y) { return std::max(x, y); }; case MIN: return [](T x, T y) { return std::min(x, y); }; default: break; } assert(false); return std::function(); } /* ----------------------- * TENSOR COMPARISON * ----------------------- */ namespace testing { template bool diff(const std::vector& hc, const std::vector& rc) { if(hc.size() != rc.size()) return false; for(size_t i = 0; i < hc.size(); i++) if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){ std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; return false; } return true; } } /* ----------------------- * PRETTY PRINTING * ----------------------- */ namespace aux{ template struct seq{}; template struct gen_seq : gen_seq{}; template struct gen_seq<0, Is...> : seq{}; template void print_tuple(std::basic_ostream& os, Tuple const& t, seq){ using swallow = int[]; (void)swallow{0, (void(os << (Is == 0? "" : ", ") << std::get(t)), 0)...}; } } // aux:: template auto operator<<(std::basic_ostream& os, std::tuple const& t) -> std::basic_ostream& { aux::print_tuple(os, t, aux::gen_seq()); return os; } template std::basic_ostream& operator<<(std::basic_ostream& os, const std::vector& vec) { os << "{"; for(size_t i = 0; i < vec.size(); i++){ if(i > 0) os << ", "; os << vec[i]; } os << "}"; return os; } template std::basic_ostream& operator<<(std::basic_ostream& os, reduce_op_t op) { return os << to_str(op); } #endif