[tests] [dot] now testing row-major
This commit is contained in:
@@ -38,7 +38,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
|
|||||||
double total_time = 0;
|
double total_time = 0;
|
||||||
op();
|
op();
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
while(total_time*1e-9 < 1e-1){
|
while(total_time*1e-9 < 1e-3){
|
||||||
float norm = 1;
|
float norm = 1;
|
||||||
// normalize clock if possible to reduce noise in auto-tuning
|
// normalize clock if possible to reduce noise in auto-tuning
|
||||||
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
|
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
|
||||||
|
@@ -20,14 +20,16 @@ unsigned allocation::is_ld_padded(ir::value *x) {
|
|||||||
if(trans->get_perm()[0]->get_value() != 0)
|
if(trans->get_perm()[0]->get_value() != 0)
|
||||||
return 4;
|
return 4;
|
||||||
}
|
}
|
||||||
|
auto order = tiles_->order(x);
|
||||||
|
bool is_col_major = order[0] == 0;
|
||||||
if(tiles_->hmma(x) == HMMA_A_ROW)
|
if(tiles_->hmma(x) == HMMA_A_ROW)
|
||||||
return 8;
|
return is_col_major ? 16 : 8;
|
||||||
if(tiles_->hmma(x) == HMMA_A_COL)
|
if(tiles_->hmma(x) == HMMA_A_COL)
|
||||||
return 16;
|
return is_col_major ? 8 : 16;
|
||||||
if(tiles_->hmma(x) == HMMA_B_COL)
|
if(tiles_->hmma(x) == HMMA_B_COL)
|
||||||
return 8;
|
return is_col_major ? 16 : 8;
|
||||||
if(tiles_->hmma(x) == HMMA_B_ROW)
|
if(tiles_->hmma(x) == HMMA_B_ROW)
|
||||||
return 16;
|
return is_col_major ? 8 : 16;
|
||||||
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
||||||
unsigned result = 0;
|
unsigned result = 0;
|
||||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++)
|
for(unsigned i = 0; i < phi->get_num_incoming(); i++)
|
||||||
|
@@ -31,7 +31,7 @@ static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vecto
|
|||||||
for(size_t n = 0; n < N; n++){
|
for(size_t n = 0; n < N; n++){
|
||||||
float acc = 0;
|
float acc = 0;
|
||||||
for(size_t k = 0; k < K; k++)
|
for(size_t k = 0; k < K; k++)
|
||||||
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
|
acc = acc + (AT ? a[k*M + m] : a[m*K + k]) * (BT ? b[n*K + k] : b[k*N + n]);
|
||||||
c[m + n*M] = static_cast<T>(acc);
|
c[m + n*M] = static_cast<T>(acc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -49,25 +49,47 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
|
|||||||
cpu_ref<T, false, false>(c, a, b, M, N, K);
|
cpu_ref<T, false, false>(c, a, b, M, N, K);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compile-time trait that maps an element type T to a type-name string.
// The primary template is only declared here; the explicit specializations
// in this file (half_float::half, float, double) each provide a
// `static constexpr const char* value`, so using to_string<T>::value with any
// other T fails to compile.
template<class T>
|
||||||
|
struct to_string;
|
||||||
|
|
||||||
bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, size_t nwarp){
|
template<> struct to_string<half_float::half>{
|
||||||
typedef float NumericT;
|
static constexpr const char* value = "half";
|
||||||
std::string ty = "float";
|
};
|
||||||
size_t dt_nbytes = sizeof(NumericT);
|
|
||||||
|
// Specialization for float: to_string<float>::value is the literal "float".
template<> struct to_string<float>{
|
||||||
|
static constexpr const char* value = "float";
|
||||||
|
};
|
||||||
|
|
||||||
|
// Specialization for double: to_string<double>::value is the literal "double".
template<> struct to_string<double>{
|
||||||
|
static constexpr const char* value = "double";
|
||||||
|
};
|
||||||
|
|
||||||
|
// Runtime tag for the element type of a test configuration.
// The non-template do_test overload switches on this value to pick the
// instantiation of do_test<T>: FLOAT -> float, HALF -> half_float::half,
// DOUBLE -> double.
enum dtype_t {
|
||||||
|
FLOAT,
|
||||||
|
HALF,
|
||||||
|
DOUBLE
|
||||||
|
};
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
bool do_test(drv::stream* stream, bool AT, bool BT,
|
||||||
|
int32_t M, int32_t N, int32_t K,
|
||||||
|
int32_t TM, int32_t TN, int32_t TK, size_t nwarp){
|
||||||
|
std::string ty = to_string<T>::value;
|
||||||
|
size_t dt_nbytes = sizeof(T);
|
||||||
drv::context* context = stream->context();
|
drv::context* context = stream->context();
|
||||||
std::vector<NumericT> hc(M*N);
|
std::vector<T> hc(M*N);
|
||||||
std::vector<NumericT> ha(M*K);
|
std::vector<T> ha(M*K);
|
||||||
std::vector<NumericT> hb(K*N);
|
std::vector<T> hb(K*N);
|
||||||
int32_t lda = AT ? K : M;
|
int32_t lda = AT ? K : M;
|
||||||
int32_t ldb = BT ? N : K;
|
int32_t ldb = BT ? N : K;
|
||||||
int32_t ldc = M;
|
int32_t ldc = M;
|
||||||
srand(0);
|
srand(0);
|
||||||
for(size_t i = 0; i < ha.size(); i++)
|
for(size_t i = 0; i < ha.size(); i++)
|
||||||
ha[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
|
ha[i] = static_cast<T>((float)rand()/RAND_MAX);
|
||||||
for(size_t i = 0; i < hb.size(); i++)
|
for(size_t i = 0; i < hb.size(); i++)
|
||||||
hb[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
|
hb[i] = static_cast<T>((float)rand()/RAND_MAX);
|
||||||
for(size_t i = 0; i < hc.size(); i++)
|
for(size_t i = 0; i < hc.size(); i++)
|
||||||
hc[i] = static_cast<NumericT>((double)0);
|
hc[i] = static_cast<T>((double)0);
|
||||||
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hc.size()*dt_nbytes));
|
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hc.size()*dt_nbytes));
|
||||||
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, ha.size()*dt_nbytes));
|
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, ha.size()*dt_nbytes));
|
||||||
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hb.size()*dt_nbytes));
|
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hb.size()*dt_nbytes));
|
||||||
@@ -92,33 +114,47 @@ bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_
|
|||||||
}
|
}
|
||||||
// test
|
// test
|
||||||
stream->read(&*dc, true, 0, hc);
|
stream->read(&*dc, true, 0, hc);
|
||||||
std::vector<NumericT> rc(hc.size());
|
std::vector<T> rc(hc.size());
|
||||||
cpu_ref(AT, BT, M, N, K, rc, ha, hb);
|
cpu_ref(AT, BT, M, N, K, rc, ha, hb);
|
||||||
return testing::diff(hc, rc);
|
return testing::diff(hc, rc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runtime-to-compile-time dispatcher for the templated do_test<T>.
// Forwards every argument except `dtype` unchanged to do_test<T>, choosing T
// from `dtype` (HALF -> half_float::half, FLOAT -> float, DOUBLE -> double),
// and returns that call's result.
// Any other `dtype` value falls through the switch and returns false, which
// the caller in main() prints as a failed test.
bool do_test(triton::driver::stream *stream,
|
||||||
|
dtype_t dtype, bool AT, bool BT,
|
||||||
|
int32_t M, int32_t N, int32_t K,
|
||||||
|
int32_t TM, int32_t TN, int32_t TK, size_t nwarp) {
|
||||||
|
switch(dtype){
|
||||||
|
case HALF: return do_test<half_float::half>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp);
|
||||||
|
case FLOAT: return do_test<float>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp);
|
||||||
|
case DOUBLE: return do_test<double>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp);
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
triton::driver::stream* stream = triton::driver::stream::create(context);
|
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||||
// shapes to benchmark
|
// shapes to benchmark
|
||||||
typedef std::tuple<bool, bool, int, int, int, int, int, int, int> config_t;
|
typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
|
||||||
std::vector<config_t> configs;
|
std::vector<config_t> configs;
|
||||||
for(bool AT: std::array<bool, 2>{false, true})
|
for(bool AT: std::array<bool, 2>{false, true})
|
||||||
for(bool BT: std::array<bool, 2>{false, true})
|
for(bool BT: std::array<bool, 2>{false, true})
|
||||||
for(int TM: std::vector<int>{16, 128})
|
for(int TM: std::vector<int>{32, 64})
|
||||||
for(int TN: std::vector<int>{16, 128})
|
for(int TN: std::vector<int>{32, 64})
|
||||||
for(int TK: std::vector<int>{16, 32})
|
for(int TK: std::vector<int>{16, 32})
|
||||||
for(int nwarps: std::vector<int>{1, 2, 4, 8}){
|
for(int nwarps: std::vector<int>{1, 2, 4, 8}){
|
||||||
configs.push_back(config_t{AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
|
configs.push_back(config_t{HALF, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
|
||||||
}
|
}
|
||||||
// does the work
|
// does the work
|
||||||
|
dtype_t dtype;
|
||||||
bool AT, BT;
|
bool AT, BT;
|
||||||
int M, N, K, TM, TN, TK, nwarp;
|
int M, N, K, TM, TN, TK, nwarp;
|
||||||
for(const auto& c: configs){
|
for(const auto& c: configs){
|
||||||
std::tie(AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
|
std::tie(dtype, AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
|
||||||
std::cout << "Testing " << c << " ... " << std::flush;
|
std::cout << "Testing " << c << " ... " << std::flush;
|
||||||
if(do_test(stream, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
|
if(do_test(stream, dtype, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
|
||||||
std::cout << " Pass! " << std::endl;
|
std::cout << " Pass! " << std::endl;
|
||||||
else{
|
else{
|
||||||
std::cout << " Fail! " << std::endl;
|
std::cout << " Fail! " << std::endl;
|
||||||
|
Reference in New Issue
Block a user