[tests] delete redundant code in dot benchmark and unit tests
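For context, a minimal sketch (not part of the commit) of how the consolidated helpers declared in the new tests/common/dot.h, bench_dot and test_dot, are meant to be called; the device/stream setup simply mirrors the main() functions visible in the diff below, and the concrete shape/tile arguments are illustrative only.

#include <iostream>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "dot.h"  // tests/common/dot.h, added by this commit

int main() {
  // default device + stream, as in the existing bench/test drivers
  auto context = triton::driver::backend::contexts::get_default();
  triton::driver::stream* stream = triton::driver::stream::create(context);
  // benchmark path: one TFLOP/s figure per library that ran (Triton only here)
  for(double tflops: bench_dot(stream, FLOAT, /*AT=*/false, /*BT=*/true, 2048, 2048, 2048))
    std::cout << tflops << std::endl;
  // unit-test path: true when the Triton result matches the CPU reference
  bool ok = test_dot(stream, FLOAT, false, true, 128, 128, 128, /*TM=*/32, /*TN=*/32, /*TK=*/8, /*nwarp=*/4);
  return ok ? 0 : 1;
}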
@@ -1,76 +1,6 @@
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/dot.h"
#include "cuda/cublas.h"

namespace drv = triton::driver;
namespace rt = triton::runtime;

inline size_t ceil(size_t x, size_t y) {
  return (x + y - 1) / y;
};

inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
  return [M, N](const rt::function::options_t& x) {
    return rt::grid_t{ceil(M, x.D<int>("TM")),
                      ceil(N, x.D<int>("TN"))};
  };
}

std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
  typedef float NumericT;
  std::string ty = "float";
  cublasDataType_t cuty = CUDA_R_32F;
  size_t dt_nbytes = sizeof(NumericT);
  drv::context* context = stream->context();
  // leading dimensions
  int32_t lda = AT ? K : M;
  int32_t ldb = BT ? N : K;
  int32_t ldc = M;
  // create inputs
  auto da = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
  auto db = std::unique_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
  auto dc = std::unique_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
  // create options
  rt::function::options_space_t opt;
  opt.defines.push_back({"TYPE", {ty}});
  opt.defines.push_back({"AT", {AT?"1":"0"}});
  opt.defines.push_back({"BT", {BT?"1":"0"}});
  opt.defines.push_back({"TM", {"128"}});
  opt.defines.push_back({"TN", {"128"}});
  opt.defines.push_back({"TK", {"8"}});
  opt.num_warps = {4};
  // create function
  rt::function function(src::dot, opt);
  // benchmark available libraries
  std::vector<double> result;
  auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
  // // cublas
  // if(cublas::cublasinit()){
  //   NumericT alpha(static_cast<double>(1));
  //   NumericT beta(static_cast<double>(0));
  //   cublasGemmAlgo_t fastest;
  //   cublasGemm(cuty, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
  //   double cublas_ms = triton::tools::bench([&]() { cublasGemm(cuty, stream, AT, BT, M, N, K,
  //                                                              &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
  //                                                              ldc, nullptr, fastest); }, stream);
  //   result.push_back(tflops(cublas_ms));
  // }
  // triton
  double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid2d(M, N), stream);}, stream);
  result.push_back(tflops(triton_ms));
  // done
  return result;
}
#include "dot.h"

int main() {
  // initialize default compute device
@@ -82,7 +12,7 @@ int main() {
  for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
                                               {true, false}, {true, true}}){
    std::vector<config_t> tmp = {
      config_t{x[0], x[1], 2048, 2048, 2048}
      config_t{x[0], x[1], 2048, 2048, 2048},
      // config_t{x[0], x[1], 16, 2048, 2048},
      // config_t{x[0], x[1], 32, 2048, 2048},
      // config_t{x[0], x[1], 64, 2048, 2048},
@@ -92,7 +22,7 @@ int main() {
      // config_t{x[0], x[1], 32, 4096, 4096},
      // config_t{x[0], x[1], 64, 4096, 4096},
      // config_t{x[0], x[1], 128, 4096, 4096},
      // config_t{x[0], x[1], 7000, 4096, 4096},
      // config_t{x[0], x[1], 7000, 4096, 4096}
    };
    configs.insert(configs.end(), tmp.begin(), tmp.end());
  }
@@ -102,7 +32,7 @@ int main() {
  for(const auto& c: configs){
    std::tie(AT, BT, M, N, K) = c;
    std::cout << "// " << AT << " " << BT << " " << M << " " << N << " " << K << std::flush;
    for(auto perf: do_bench(stream, AT, BT, M, N, K))
    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K))
      std::cout << ", " << perf << std::flush;
    std::cout << std::endl;
  }
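Aside (not part of the diff): the launch-grid arithmetic that grid2d encapsulates, presumably now provided by the shared headers that dot.h includes, is just a ceiling division of the output shape by the tile shape taken from the compile-time options. A standalone sketch, with ceil_div as a stand-in name for the removed ceil() helper:

#include <cstddef>
#include <cstdio>

// number of tiles of size y needed to cover x
inline size_t ceil_div(size_t x, size_t y) { return (x + y - 1) / y; }

int main() {
  size_t M = 2048, N = 2048;   // benchmark problem size from the configs above
  size_t TM = 128, TN = 128;   // BENCH-mode tile sizes from the diff
  std::printf("%zu x %zu programs\n", ceil_div(M, TM), ceil_div(N, TN));  // 16 x 16
  return 0;
}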
tests/common/dot.h (new file, 191 lines)
@@ -0,0 +1,191 @@
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/dot.h"
#include "cuda/cublas.h"
#include "util.h"

template<class T, bool AT, bool BT>
static void cc_dot(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
                   size_t M, size_t N, size_t K){
  for(size_t m = 0; m < M; m++)
    for(size_t n = 0; n < N; n++){
      float acc = 0;
      for(size_t k = 0; k < K; k++)
        acc = acc + (AT ? a[k*M + m] : a[m*K + k]) * (BT ? b[n*K + k] : b[k*N + n]);
      c[m + n*M] = static_cast<T>(acc);
    }
}

template<class T>
void cc_dot(bool AT_, bool BT_, size_t M, size_t N, size_t K,
            std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
  if(AT_ && BT_)
    cc_dot<T, true, true>(c, a, b, M, N, K);
  else if(AT_ && !BT_)
    cc_dot<T, true, false>(c, a, b, M, N, K);
  else if(!AT_ && BT_)
    cc_dot<T, false, true>(c, a, b, M, N, K);
  else
    cc_dot<T, false, false>(c, a, b, M, N, K);
}

enum run_mode_t {
  BENCH,
  TEST
};

enum dtype_t {
  FLOAT,
  HALF,
  DOUBLE
};

template<class T>
struct to_string;

template<> struct to_string<half_float::half>{
  static constexpr const char* value = "half";
};

template<> struct to_string<float>{
  static constexpr const char* value = "float";
};

template<> struct to_string<double>{
  static constexpr const char* value = "double";
};

template<class T>
void triton_dot(drv::stream* stream, bool AT, bool BT,
                int32_t M, int32_t N, int32_t K,
                int32_t TM, int32_t TN, int32_t TK, size_t nwarp,
                run_mode_t mode, std::vector<double>& bench, bool &test){
  std::string ty = to_string<T>::value;
  size_t dt_nbytes = sizeof(T);
  drv::context* context = stream->context();
  int32_t lda = AT ? K : M;
  int32_t ldb = BT ? N : K;
  int32_t ldc = M;

  // inputs
  auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
  auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
  auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));

  // macros
  rt::function::options_space_t opt;
  // B access patterns
  opt.defines.push_back({"USEB",         {BT? "^b"         : "b"          }});
  opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }});
  opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }});
  opt.defines.push_back({"SHAPE_B",      {BT? "TN, TK"     : "TK, TN"     }});
  opt.defines.push_back({"STRIDE_BK",    {BT? "1"          : "ldb"        }});
  opt.defines.push_back({"STRIDE_BN",    {BT? "ldb"        : "1"          }});
  // A access patterns
  opt.defines.push_back({"USEA",         {AT? "^a"         : "a"          }});
  opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }});
  opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }});
  opt.defines.push_back({"SHAPE_A",      {AT? "TK, TM"     : "TM, TK"     }});
  opt.defines.push_back({"STRIDE_AK",    {AT? "lda"        : "1"          }});
  opt.defines.push_back({"STRIDE_AM",    {AT? "1"          : "lda"        }});
  // data-type
  opt.defines.push_back({"TYPE", {ty}});
  // tile sizes
  if(mode == TEST) {
    opt.defines.push_back({"TM", {std::to_string(TM)}});
    opt.defines.push_back({"TN", {std::to_string(TN)}});
    opt.defines.push_back({"TK", {std::to_string(TK)}});
    opt.num_warps = {nwarp};
  }
  if(mode == BENCH) {
    opt.defines.push_back({"TM", {"128"}});
    opt.defines.push_back({"TN", {"128"}});
    opt.defines.push_back({"TK", {"8"}});
    opt.num_warps = {4};
  }

  // kernels
  rt::function function(src::dot, opt);
  std::vector<rt::arg> args = {&*da, &*db, &*dc, M, N, K, lda, ldb, ldc};
  auto grid = grid2d(M, N);

  // metrics
  if(mode == BENCH){
    auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
    double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
    bench.push_back(tflops(triton_ns));

    // // cublas
    // if(cublas::cublasinit()){
    //   NumericT alpha(static_cast<double>(1));
    //   NumericT beta(static_cast<double>(0));
    //   cublasGemmAlgo_t fastest;
    //   cublasGemm(cuty, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
    //   double cublas_ms = triton::tools::bench([&]() { cublasGemm(cuty, stream, AT, BT, M, N, K,
    //                                                              &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
    //                                                              ldc, nullptr, fastest); }, stream);
    //   result.push_back(tflops(cublas_ms));
    // }
  }

  // test triton
  if(mode == TEST){
    srand(0);
    // initialize buffers
    std::vector<T> hc(M*N);
    std::vector<T> ha(M*K);
    std::vector<T> hb(K*N);
    for(size_t i = 0; i < ha.size(); i++)
      ha[i] = static_cast<T>((float)rand()/RAND_MAX);
    for(size_t i = 0; i < hb.size(); i++)
      hb[i] = static_cast<T>((float)rand()/RAND_MAX);
    // copy buffer
    stream->write(&*da, true, 0, ha);
    stream->write(&*db, true, 0, hb);
    // run kernel
    function(args, grid, stream);
    // write back
    stream->synchronize();
    // compare with CPU
    stream->read(&*dc, true, 0, hc);
    std::vector<T> rc(hc.size());
    cc_dot(AT, BT, M, N, K, rc, ha, hb);
    test = testing::diff(hc, rc);
  }
}

std::vector<double> bench_dot(drv::stream* stream,
                              dtype_t dtype, bool AT, bool BT,
                              int32_t M, int32_t N, int32_t K) {
  std::vector<double> bench;
  bool test;
  switch(dtype){
    case HALF:   triton_dot<half_float::half>(stream, AT, BT, M, N, K, 0, 0, 0, 0, BENCH, bench, test); break;
    case FLOAT:  triton_dot<float>(stream, AT, BT, M, N, K, 0, 0, 0, 0, BENCH, bench, test); break;
    case DOUBLE: triton_dot<double>(stream, AT, BT, M, N, K, 0, 0, 0, 0, BENCH, bench, test); break;
    default: break;
  }
  return bench;
}

bool test_dot(drv::stream* stream,
              dtype_t dtype, bool AT, bool BT,
              int32_t M, int32_t N, int32_t K,
              int32_t TM, int32_t TN, int32_t TK, size_t nwarp) {
  std::vector<double> bench;
  bool test = false;
  switch(dtype){
    case HALF:   triton_dot<half_float::half>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, TEST, bench, test); break;
    case FLOAT:  triton_dot<float>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, TEST, bench, test); break;
    case DOUBLE: triton_dot<double>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, TEST, bench, test); break;
    default: break;
  }
  return test;
}
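Aside (not part of the diff): the unit conversion inside the tflops lambda above, assuming the benchmark timer reports nanoseconds as the parameter name suggests. A GEMM performs 2*M*N*K floating-point operations, flops per nanosecond equals GFLOP/s, and the extra 1e-3 factor yields TFLOP/s:

#include <cstdio>

int main() {
  double M = 2048, N = 2048, K = 2048;
  double nanosec = 1e6;  // pretend one GEMM took 1 ms
  double tflops = 2. * M * N * K / nanosec * 1e-3;
  std::printf("%.2f TFLOP/s\n", tflops);  // ~17.18
  return 0;
}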
@@ -2,38 +2,6 @@ namespace src {

const char *dot =
R"(
#if AT == 1
#define USEA ^a
#define STRIDE_AK lda
#define STRIDE_AM 1
#define BROADCAST_AK :, newaxis
#define BROADCAST_AM newaxis, :
#define SHAPE_A TK, TM
#else
#define USEA a
#define STRIDE_AK 1
#define STRIDE_AM lda
#define BROADCAST_AK newaxis, :
#define BROADCAST_AM :, newaxis
#define SHAPE_A TM, TK
#endif

#if BT == 1
#define USEB ^b
#define STRIDE_BK 1
#define STRIDE_BN ldb
#define BROADCAST_BK newaxis, :
#define BROADCAST_BN :, newaxis
#define SHAPE_B TN, TK
#else
#define USEB b
#define STRIDE_BK ldb
#define STRIDE_BN 1
#define BROADCAST_BK :, newaxis
#define BROADCAST_BN newaxis, :
#define SHAPE_B TK, TN
#endif

void dot(TYPE * A, TYPE * B, TYPE * C,
         int M, int N, int K,
         int lda __multipleof(8),
@@ -1,142 +1,13 @@
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/dot.h"
#include "cuda/cublas.h"
#include "dot.h"
#include "util.h"

namespace drv = triton::driver;
namespace rt = triton::runtime;

template<class T>
void diff(const std::vector<T>& x, const std::vector<T>& y){
  for(size_t i = 0; i < x.size(); i++)
    if(std::isnan(x[i]) || std::abs(x[i] - y[i])/std::max(x[i], y[i]) > 1e-4){
      std::cout << i << " " << x[i] << " " << y[i] << std::endl;
      exit(EXIT_FAILURE);
    }
  std::cout << "Pass!" << std::endl;
}

template<class T, bool AT, bool BT>
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
                    size_t M, size_t N, size_t K){
  for(size_t m = 0; m < M; m++)
    for(size_t n = 0; n < N; n++){
      float acc = 0;
      for(size_t k = 0; k < K; k++)
        acc = acc + (AT ? a[k*M + m] : a[m*K + k]) * (BT ? b[n*K + k] : b[k*N + n]);
      c[m + n*M] = static_cast<T>(acc);
    }
}

template<class T>
void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
             std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
  if(AT_ && BT_)
    cpu_ref<T, true, true>(c, a, b, M, N, K);
  else if(AT_ && !BT_)
    cpu_ref<T, true, false>(c, a, b, M, N, K);
  else if(!AT_ && BT_)
    cpu_ref<T, false, true>(c, a, b, M, N, K);
  else
    cpu_ref<T, false, false>(c, a, b, M, N, K);
}

template<class T>
struct to_string;

template<> struct to_string<half_float::half>{
  static constexpr const char* value = "half";
};

template<> struct to_string<float>{
  static constexpr const char* value = "float";
};

template<> struct to_string<double>{
  static constexpr const char* value = "double";
};

enum dtype_t {
  FLOAT,
  HALF,
  DOUBLE
};

template<class T>
bool do_test(drv::stream* stream, bool AT, bool BT,
             int32_t M, int32_t N, int32_t K,
             int32_t TM, int32_t TN, int32_t TK, size_t nwarp){
  std::string ty = to_string<T>::value;
  size_t dt_nbytes = sizeof(T);
  drv::context* context = stream->context();
  std::vector<T> hc(M*N);
  std::vector<T> ha(M*K);
  std::vector<T> hb(K*N);
  int32_t lda = AT ? K : M;
  int32_t ldb = BT ? N : K;
  int32_t ldc = M;
  srand(0);
  for(size_t i = 0; i < ha.size(); i++)
    ha[i] = static_cast<T>((float)rand()/RAND_MAX);
  for(size_t i = 0; i < hb.size(); i++)
    hb[i] = static_cast<T>((float)rand()/RAND_MAX);
  for(size_t i = 0; i < hc.size(); i++)
    hc[i] = static_cast<T>((double)0);
  auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hc.size()*dt_nbytes));
  auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, ha.size()*dt_nbytes));
  auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hb.size()*dt_nbytes));
  stream->write(&*da, true, 0, ha);
  stream->write(&*db, true, 0, hb);
  stream->write(&*dc, true, 0, hc);
  stream->synchronize();
  // run
  rt::function::options_space_t opt;
  opt.defines.push_back({"TYPE", {ty}});
  opt.defines.push_back({"AT", {AT?"1":"0"}});
  opt.defines.push_back({"BT", {BT?"1":"0"}});
  opt.defines.push_back({"TM", {std::to_string(TM)}});
  opt.defines.push_back({"TN", {std::to_string(TN)}});
  opt.defines.push_back({"TK", {std::to_string(TK)}});
  opt.num_warps = {nwarp};
  rt::function function(src::dot, opt);
  try {
    function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid2d(M, N), stream);
  } catch (const std::runtime_error& e) {
    return true;
  }
  // test
  stream->read(&*dc, true, 0, hc);
  std::vector<T> rc(hc.size());
  cpu_ref(AT, BT, M, N, K, rc, ha, hb);
  return testing::diff(hc, rc);
}

bool do_test(triton::driver::stream *stream,
             dtype_t dtype, bool AT, bool BT,
             int32_t M, int32_t N, int32_t K,
             int32_t TM, int32_t TN, int32_t TK, size_t nwarp) {
  switch(dtype){
    case HALF:   return do_test<half_float::half>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp);
    case FLOAT:  return do_test<float>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp);
    case DOUBLE: return do_test<double>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp);
    default: break;
  }
  return false;
}

int main() {
  // initialize default compute device
  auto context = triton::driver::backend::contexts::get_default();
  triton::driver::stream* stream = triton::driver::stream::create(context);
  // shapes to benchmark
  // shapes to test
  typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
  std::vector<config_t> configs;
  for(int TM: std::vector<int>{32, 64})
@@ -147,14 +18,14 @@ int main() {
  for(bool BT: std::array<bool, 2>{false, true}){
    configs.push_back(config_t{FLOAT, AT, BT, 128, 128, 128, TM, TN, TK, nwarps});
  }
  // does the work
  // test
  dtype_t dtype;
  bool AT, BT;
  int M, N, K, TM, TN, TK, nwarp;
  for(const auto& c: configs){
    std::tie(dtype, AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
    std::cout << "Testing " << c << " ... " << std::flush;
    if(do_test(stream, dtype, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
    if(test_dot(stream, dtype, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
      std::cout << " Pass! " << std::endl;
    else{
      std::cout << " Fail! " << std::endl;
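Aside (not part of the diff): a tiny standalone check of the indexing convention shared by the old cpu_ref and the new cc_dot: A and B are read row-major (or transposed when AT/BT is set), while C is written column-major via c[m + n*M]. With A the 2x2 identity, the output is just B stored column-major. reference_dot below is a stand-in name for that helper.

#include <cstddef>
#include <cstdio>
#include <vector>

// same indexing as cc_dot<T, AT, BT> in tests/common/dot.h
template<class T, bool AT, bool BT>
void reference_dot(std::vector<T>& c, const std::vector<T>& a, const std::vector<T>& b,
                   size_t M, size_t N, size_t K) {
  for(size_t m = 0; m < M; m++)
    for(size_t n = 0; n < N; n++) {
      float acc = 0;
      for(size_t k = 0; k < K; k++)
        acc += (AT ? a[k*M + m] : a[m*K + k]) * (BT ? b[n*K + k] : b[k*N + n]);
      c[m + n*M] = static_cast<T>(acc);
    }
}

int main() {
  std::vector<float> a = {1, 0, 0, 1};  // 2x2 identity, row-major
  std::vector<float> b = {1, 2, 3, 4};  // 2x2, row-major
  std::vector<float> c(4);
  reference_dot<float, false, false>(c, a, b, 2, 2, 2);
  for(float x: c) std::printf("%g ", x);  // prints 1 3 2 4: B laid out column-major
  std::printf("\n");
  return 0;
}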