Core: Added double-precision tuning, tests and benchmarks

This commit is contained in:
Philippe Tillet
2016-11-20 22:36:08 -05:00
parent 1f03389be0
commit 54b5b7523d
15 changed files with 184519 additions and 106392 deletions

View File

@@ -22,6 +22,42 @@ typedef sc::int_t int_t;
Timer tmr;
+/* C++ wrapper for BLAS */
+#ifdef BENCH_CLBLAS
+template<typename... Args> void clblasAxpy(float, Args... args){ clblasSaxpy(args...); }
+template<typename... Args> void clblasAxpy(double, Args... args){ clblasDaxpy(args...); }
+template<typename... Args> void clblasDot(float, Args... args){ clblasSdot(args...); }
+template<typename... Args> void clblasDot(double, Args... args){ clblasDdot(args...); }
+template<typename... Args> void clblasGemv(float, Args... args){ clblasSgemv(args...); }
+template<typename... Args> void clblasGemv(double, Args... args){ clblasDgemv(args...); }
+template<typename... Args> void clblasGemm(float, Args... args){ clblasSgemm(args...); }
+template<typename... Args> void clblasGemm(double, Args... args){ clblasDgemm(args...); }
+#endif
+#ifdef BENCH_CBLAS
+template<typename... Args> void cblasAxpy(float, Args... args){ cblas_saxpy(args...); }
+template<typename... Args> void cblasAxpy(double, Args... args){ cblas_daxpy(args...); }
+template<typename... Args> void cblasDot(float, Args... args){ cblas_sdot(args...); }
+template<typename... Args> void cblasDot(double, Args... args){ cblas_ddot(args...); }
+template<typename... Args> void cblasGemv(float, Args... args){ cblas_sgemv(args...); }
+template<typename... Args> void cblasGemv(double, Args... args){ cblas_dgemv(args...); }
+template<typename... Args> void cblasGemm(float, Args... args){ cblas_sgemm(args...); }
+template<typename... Args> void cblasGemm(double, Args... args){ cblas_dgemm(args...); }
+#endif
+//cuBLAS
+#ifdef BENCH_CUBLAS
+template<typename... Args> void cublasAxpy(float, Args... args){ cublasSaxpy(args...); }
+template<typename... Args> void cublasAxpy(double, Args... args){ cublasDaxpy(args...); }
+template<typename... Args> void cublasDot(float, Args... args){ cublasSdot(args...); }
+template<typename... Args> void cublasDot(double, Args... args){ cublasDdot(args...); }
+template<typename... Args> void cublasGemv(float, Args... args){ cublasSgemv(args...); }
+template<typename... Args> void cublasGemv(double, Args... args){ cublasDgemv(args...); }
+template<typename... Args> void cublasGemm(float, Args... args){ cublasSgemm(args...); }
+template<typename... Args> void cublasGemm(double, Args... args){ cublasDgemm(args...); }
+#endif
+//
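These wrappers implement tag dispatch: a dummy first argument of type float or double selects the single- or double-precision BLAS entry point at compile time, so each benchmark body below is written once and instantiated for both precisions. A minimal self-contained sketch of the idiom, with hypothetical saxpy/daxpy stand-ins for the real BLAS symbols:

    #include <cstdio>

    // Hypothetical kernels standing in for clblasSaxpy / clblasDaxpy.
    void saxpy(int n, float a, const float* x, float* y)
    { for(int i = 0; i < n; ++i) y[i] += a*x[i]; }
    void daxpy(int n, double a, const double* x, double* y)
    { for(int i = 0; i < n; ++i) y[i] += a*x[i]; }

    // The dummy first parameter carries only the precision; every other
    // argument is forwarded untouched.
    template<typename... Args> void axpy(float,  Args... args){ saxpy(args...); }
    template<typename... Args> void axpy(double, Args... args){ daxpy(args...); }

    template<class T>
    void run()
    {
      T x[4] = {1, 2, 3, 4}, y[4] = {4, 3, 2, 1};
      axpy(T(), 4, T(2), x, y);          // T() picks the overload, T(2) is alpha
      std::printf("%g\n", double(y[0])); // prints 6
    }

    int main(){ run<float>(); run<double>(); }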
template<class OP, class SYNC>
double bench(OP const & op, SYNC const & sync)
{
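The body of bench() is elided from this hunk; from its call sites, it times a callable repeatedly and drains the device queue through the sync callback so that asynchronous kernel launches are included in the measurement. A hedged sketch of such a harness (the run count and the median reduction are assumptions, not necessarily the commit's actual policy):

    #include <algorithm>
    #include <chrono>
    #include <cstddef>
    #include <vector>

    // Times `op`, calling `sync` after each run to flush the device queue.
    // Returns the median wall-clock time of nruns runs, in nanoseconds.
    template<class OP, class SYNC>
    double bench_sketch(OP const & op, SYNC const & sync, std::size_t nruns = 10)
    {
      op(); sync();                                  // warm-up, not timed
      std::vector<double> times;
      for(std::size_t i = 0; i < nruns; ++i){
        auto t0 = std::chrono::high_resolution_clock::now();
        op(); sync();
        auto t1 = std::chrono::high_resolution_clock::now();
        times.push_back(std::chrono::duration<double, std::nano>(t1 - t0).count());
      }
      std::sort(times.begin(), times.end());
      return times[times.size()/2];
    }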
@@ -51,6 +87,10 @@ void bench(sc::numeric_type dtype, std::string operation)
#ifdef BENCH_CUBLAS
auto cusync = [&](){ cudaDeviceSynchronize(); };
#endif
+bool on_cl = queue.backend()==sc::driver::OPENCL;
+bool on_cu = queue.backend()==sc::driver::CUDA;
/*---------*/
/*--BLAS1--*/
/*---------*/
@@ -65,17 +105,18 @@ void bench(sc::numeric_type dtype, std::string operation)
//Bench
times.push_back(bench([&](){y = x + alpha*y;}, sync));
#ifdef BENCH_CLBLAS
-if(x.context().backend()==sc::driver::OPENCL)
-times.push_back(bench([&]() {clblasSaxpy(N, alpha, cl(x), 0, 1, cl(y), 0, 1, 1, &cl(queue), 0, NULL, NULL);}, sync));
+if(on_cl)
+times.push_back(bench([&]() {clblasAxpy(T(), N, alpha, cl(x), 0, 1, cl(y), 0, 1, 1, &cl(queue), 0, nullptr, nullptr);}, sync));
#endif
#ifdef BENCH_CBLAS
std::vector<float> cx(N), cy(N);
sc::copy(x, cx);
sc::copy(y, cy);
-times.push_back(bench([&](){cblas_saxpy(N, alpha, cx.data(), 1, cy.data(), 1);}, sync));
+times.push_back(bench([&](){cblasAxpy(T(), N, alpha, cx.data(), 1, cy.data(), 1);}, sync));
#endif
#ifdef BENCH_CUBLAS
-times.push_back(bench([&](){cublasSaxpy(N, alpha, (T*)cu(x), 1, (T*)cu(y), 1);}, cusync));
+if(on_cu)
+times.push_back(bench([&](){cublasAxpy(T(), N, alpha, (T*)cu(x), 1, (T*)cu(y), 1);}, cusync));
#endif
}
}
@@ -91,17 +132,18 @@ void bench(sc::numeric_type dtype, std::string operation)
//Bench
times.push_back(bench([&](){s = dot(x,y);}, sync));
#ifdef BENCH_CLBLAS
-if(x.context().backend()==sc::driver::OPENCL)
-times.push_back(bench([&]() {clblasSdot(N, cl(s), 0, cl(x), 0, 1, cl(y), 0, 1, cl(scratch), 1, &cl(queue), 0, NULL, NULL);}, sync));
+if(on_cl)
+times.push_back(bench([&]() {clblasDot(T(), N, cl(s), 0, cl(x), 0, 1, cl(y), 0, 1, cl(scratch), 1, &cl(queue), 0, nullptr, nullptr);}, sync));
#endif
#ifdef BENCH_CBLAS
std::vector<float> cx(N), cy(N);
sc::copy(x, cx);
sc::copy(y, cy);
-times.push_back(bench([&](){cblas_sdot(N, cx.data(), 1, cy.data(), 1);}, sync));
+times.push_back(bench([&](){cblasDot(T(), N, cx.data(), 1, cy.data(), 1);}, sync));
#endif
#ifdef BENCH_CUBLAS
-times.push_back(bench([&](){cublasSdot(N, (T*)cu(x), 1, (T*)cu(y), 1);}, cusync));
+if(on_cu)
+times.push_back(bench([&](){cublasDot(T(), N, (T*)cu(x), 1, (T*)cu(y), 1);}, cusync));
#endif
}
}
@@ -146,18 +188,19 @@ void bench(sc::numeric_type dtype, std::string operation)
//Bench
times.push_back(bench([&](){y = AT?dot(A.T,x):dot(A,x);}, sync));
#ifdef BENCH_CLBLAS
-if(x.context().backend()==sc::driver::OPENCL)
-times.push_back(bench([&]() {clblasSgemv(clblasColumnMajor, AT?clblasTrans:clblasNoTrans, As1, As2, 1, cl(A), 0, lda, cl(x), 0, 1, 0, cl(y), 0, 1, 1, &cl(queue),0, NULL, NULL);}, sync));
+if(on_cl)
+times.push_back(bench([&]() {clblasGemv(T(), clblasColumnMajor, AT?clblasTrans:clblasNoTrans, As1, As2, 1, cl(A), 0, lda, cl(x), 0, 1, 0, cl(y), 0, 1, 1, &cl(queue),0, nullptr, nullptr);}, sync));
#endif
#ifdef BENCH_CBLAS
std::vector<float> cA(M*N), cx(N), cy(M);
sc::copy(x, cx);
sc::copy(y, cy);
sc::copy(A, cA);
-times.push_back(bench([&](){cblas_sgemv(CblasColMajor, AT?CblasTrans:CblasNoTrans, As1, As2, 1, cA.data(), lda, cx.data(), 1, 0, cy.data(), 1);}, sync));
+times.push_back(bench([&](){cblasGemv(T(), CblasColMajor, AT?CblasTrans:CblasNoTrans, As1, As2, 1, cA.data(), lda, cx.data(), 1, 0, cy.data(), 1);}, sync));
#endif
#ifdef BENCH_CUBLAS
-times.push_back(bench([&](){cublasSgemv(AT?'t':'n', As1, As2, 1, (T*)cu(A), lda, (T*)cu(x), 1, 0, (T*)cu(y), 1);}, cusync));
+if(on_cu)
+times.push_back(bench([&](){cublasGemv(T(), AT?'t':'n', As1, As2, 1, (T*)cu(A), lda, (T*)cu(x), 1, 0, (T*)cu(y), 1);}, cusync));
#endif
}
}
@@ -185,12 +228,14 @@ void bench(sc::numeric_type dtype, std::string operation)
std::cout << color_stream(ITALIC) << color_stream(BOLD) ;
std::cout << "BENCH\tM\tN\tK\tAT\tBT\tISAAC";
#ifdef BENCH_CLBLAS
+if(on_cl)
std::cout << "\tclBLAS";
#endif
#ifdef BENCH_CBLAS
std::cout << "\tBLAS";
#endif
#ifdef BENCH_CUBLAS
+if(on_cu)
std::cout << "\tcuBLAS";
#endif
std::cout << color_stream(RESET) << std::endl;
@@ -224,20 +269,21 @@ void bench(sc::numeric_type dtype, std::string operation)
:(BT?dot(A,B.T)
:dot(A,B));}, sync));
#ifdef BENCH_CLBLAS
-if(C.context().backend()==sc::driver::OPENCL)
-times.push_back(bench([&]() {clblasSgemm(clblasColumnMajor, AT?clblasTrans:clblasNoTrans, BT?clblasTrans:clblasNoTrans,
+if(on_cl)
+times.push_back(bench([&]() {clblasGemm(T(), clblasColumnMajor, AT?clblasTrans:clblasNoTrans, BT?clblasTrans:clblasNoTrans,
M, N, K, 1, cl(A), 0, lda, cl(B), 0, ldb,
-0, cl(C), 0, ldc, 1, &cl(queue),0, NULL, NULL);}, sync));
+0, cl(C), 0, ldc, 1, &cl(queue),0, nullptr, nullptr);}, sync));
#endif
#ifdef BENCH_CBLAS
std::vector<float> cC(M*N), cA(M*K), cB(N*K);
sc::copy(C, cC);
sc::copy(A, cA);
sc::copy(B, cB);
-times.push_back(bench([&](){cblas_sgemm(CblasColMajor, AT?CblasTrans:CblasNoTrans, BT?CblasTrans:CblasNoTrans, M, N, K, 1, cA.data(), lda, cB.data(), ldb, 1, cC.data(), ldc);}, sync));
+times.push_back(bench([&](){cblasGemm(T(), CblasColMajor, AT?CblasTrans:CblasNoTrans, BT?CblasTrans:CblasNoTrans, M, N, K, 1, cA.data(), lda, cB.data(), ldb, 1, cC.data(), ldc);}, sync));
#endif
#ifdef BENCH_CUBLAS
-times.push_back(bench([&](){cublasSgemm(AT?'t':'n', BT?'t':'n', M, N, K, 1, (T*)cu(A), lda, (T*)cu(B), ldb, 1, (T*)cu(C), ldc);}, cusync));
+if(on_cu)
+times.push_back(bench([&](){cublasGemm(T(), AT?'t':'n', BT?'t':'n', M, N, K, 1, (T*)cu(A), lda, (T*)cu(B), ldb, 1, (T*)cu(C), ldc);}, cusync));
#endif
std::transform(times.begin(), times.end(), std::back_inserter(tflops), [&](double t){ return 2*M*N*K/t*1e-3;});
auto fastest = tflops;
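The transform above is the performance metric: a GEMM of sizes M, N, K executes 2*M*N*K floating-point operations (one multiply plus one add per term of each inner product). Assuming bench() reports times in nanoseconds, which is what makes the variable name come out right, 2*M*N*K/t is GFLOPS and the extra 1e-3 factor converts to TFLOPS; for example, M = N = K = 2048 completed in t = 1e6 ns (one millisecond) scores 2*2048^3/1e6*1e-3 ≈ 17.2 TFLOPS.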
@@ -256,6 +302,34 @@ void bench(sc::numeric_type dtype, std::string operation)
}
+void handle_misusage(){
+std::cerr << "Usage : blas-bench [--dtype {float32, float64}] [--device DEVICE_IDX] [--help]" << std::endl;
+// std::cerr << "--op: operation to benchmark" << std::endl;
+std::cerr << "--dtype: data-type to benchmark" << std::endl;
+std::cerr << "--device: index of isaac device in [0, ..., ndevices - 1]" << std::endl;
+std::cerr << "--help: display this message" << std::endl;
+exit(EXIT_FAILURE);
+}
+std::string getopt(std::vector<std::string> const & args,
+std::string const & key,
+std::vector<std::string> const & set = {},
+std::string dft = "")
+{
+auto it = std::find(args.begin(), args.end(), key);
+if(it==args.end()){
+if(dft.empty())
+handle_misusage();
+return dft;
+}
+auto next = it + 1;
+if(next==args.end() || next->compare(0, 2, "--")==0)
+handle_misusage();
+if(set.size() && std::find(set.begin(), set.end(), *next)==set.end())
+handle_misusage();
+return *next;
+}
int main(int argc, char* argv[])
{
std::vector<std::string> args(argv, argv + argc);
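The getopt helper defined above is a linear scan over the argument list: it locates the key, rejects a missing value or one that looks like another flag, optionally validates the value against a whitelist, and otherwise falls back to the default, with an empty default marking the flag as mandatory. A hypothetical invocation:

    // Hypothetical command line: blas-bench --dtype float64 --device 1
    std::vector<std::string> args = {"blas-bench", "--dtype", "float64", "--device", "1"};
    std::string dtype = getopt(args, "--dtype", {"float32", "float64"}, "float32"); // "float64"
    std::string dev   = getopt(args, "--device", {}, "0");                          // "1"
    std::string op    = getopt(args, "--op", {}, "gemm");                           // absent: default "gemm"
    // getopt(args, "--out", {}, "") would call handle_misusage(): the key is
    // absent and the empty default marks it as required.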
@@ -264,41 +338,38 @@ int main(int argc, char* argv[])
#endif
sc::driver::backend::default_queue_properties = CL_QUEUE_PROFILING_ENABLE;
-int device_idx = 0;
-std::list<sc::driver::Context const *> contexts;
-sc::driver::backend::contexts::get(contexts);
+if(std::find(args.begin(), args.end(), "--help") != args.end())
+handle_misusage();
-std::string operation;
-if(contexts.size() > 1)
-{
-if(args.size() != 3)
-{
-std::cerr << "usage : blas-bench DEVICE_IDX OPERATION" << std::endl;
-std::cout << "Devices available: " << std::endl;
-unsigned int current=0;
-for(sc::driver::Context const * context: contexts)
-{
-sc::driver::Device device = sc::driver::backend::queues::get(*context,0).device();
-std::cout << current++ << ": " << device.name() << " on " << device.platform().name() << " " << device.platform().version() << std::endl;
-}
-exit(EXIT_FAILURE);
-}
-device_idx = atoi(argv[1]);
-operation = args[2];
-}
-else
-{
-if(args.size() != 2)
-{
-std::cerr << "usage : blas-bench OPERATION" << std::endl;
-exit(EXIT_FAILURE);
-}
-operation = args[1];
-}
+std::string operation = "gemm";
+std::string dtype = getopt(args, "--dtype", {"float32", "float64"}, "float32");
+int device;
+try{
+device = std::stoi(getopt(args, "--device", {}, "0"));
+}catch(...){ handle_misusage(); }
+sc::driver::backend::default_device = device;
+/* List devices */
+std::cout << "Devices available:" << std::endl;
+std::cout << "------------------" << std::endl;
+size_t i = 0;
+std::vector<sc::driver::Platform> platforms;
+sc::driver::backend::platforms(platforms);
+for(sc::driver::Platform const & pf: platforms){
+std::vector<sc::driver::Device> devices;
+pf.devices(devices);
+for(sc::driver::Device const & device: devices)
+std::cout << "[" << (i++==sc::driver::backend::default_device?"x":" ") << "]"
+<< " - " << device.name()
+<< " on " << pf.name() << std::endl;
+}
+std::cout << "------------------" << std::endl;
-sc::driver::backend::default_device = device_idx;
std::cout << std::fixed << std::setprecision(2);
if(dtype=="float32")
bench<float>(sc::FLOAT_TYPE, operation);
else
bench<double>(sc::DOUBLE_TYPE, operation);
#ifdef BENCH_CLBLAS
clblasTeardown();

BIN
build/bench/bench-blas Executable file

Binary file not shown.

View File

@@ -134,7 +134,7 @@ public:
data_type const & data() const;
std::size_t root() const;
driver::Context const & context() const;
-numeric_type const & dtype() const;
+numeric_type dtype() const;
node const & operator[](size_t) const;
node & operator[](size_t);

View File

@@ -116,7 +116,7 @@ std::size_t expression_tree::root() const
driver::Context const & expression_tree::context() const
{ return *context_; }
-numeric_type const & expression_tree::dtype() const
+numeric_type expression_tree::dtype() const
{ return tree_[root_].dtype; }
tuple expression_tree::shape() const
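Returning the enumerator by value instead of by const reference is the safer signature: numeric_type is a small enum, so the copy is free, a by-value result cannot dangle when the tree is a temporary, and, as the export_core hunk further down shows, it lets Boost.Python's add_property bind the getter directly where a reference return would need an explicit return-value policy. A small sketch of the hazard, with a hypothetical stand-in for expression_tree:

    #include <iostream>

    enum numeric_type { FLOAT_TYPE, DOUBLE_TYPE };

    // Hypothetical stand-in for expression_tree, reduced to its dtype getter.
    struct tree_sketch {
      numeric_type dt;
      numeric_type dtype() const { return dt; }  // by value: an enum copy is free
    };

    tree_sketch make_tree() { return {DOUBLE_TYPE}; }

    int main(){
      // With the old `numeric_type const &` signature, holding on to the result
      // of a temporary tree would dangle:
      //   numeric_type const & bad = make_tree().dtype();
      // The by-value return avoids the hazard entirely.
      numeric_type dt = make_tree().dtype();
      std::cout << (dt == DOUBLE_TYPE) << "\n";  // prints 1
    }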

View File

@@ -48,7 +48,7 @@ namespace runtime
const profiles::presets_type profiles::presets_ =
{
//DEFAULT
-DATABASE_ENTRY(UNKNOWN, UNKNOWN, UNKNOWN, database::unknown::unknown),
+DATABASE_ENTRY(UNKNOWN, UNKNOWN, UNKNOWN, database::nvidia::sm_6_1),
//INTEL
DATABASE_ENTRY(GPU, INTEL, BROADWELL, database::intel::broadwell),
//NVIDIA

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View File

@@ -87,11 +87,11 @@ void profiles::value_type::execute(runtime::execution_handler const & expr)
//Cached
size_t label = 0;
auto it = labels_.find(x);
-if(it!=labels_.end()){
+if(it!=labels_.end())
label = it->second;
-}
//Not cached
-else if(predictor_){
+else if(predictor_)
+{
expression_tree::node const & root = tree[tree.root()];
expression_tree::node const & left = tree[root.binary_operator.lhs];
array_base* out = left.array.base;
@@ -100,12 +100,14 @@ void profiles::value_type::execute(runtime::execution_handler const & expr)
};
bool modify_output = std::find_if(tree.data().begin(), tree.data().end(), read_out) != tree.data().end();
std::unique_ptr<array> bkp;
-if(modify_output)
-bkp.reset(new array(*out));
+if(modify_output){
+bkp.reset(new array(out->shape(), out->dtype(), queue_.context()));
+*bkp = execution_handler(-(-*out), execution_options_type(queue_));
+}
tools::Timer tmr;
std::vector<double> times;
std::vector<float> perf = predictor_->predict(x);
-std::vector<size_t> idx(templates_.size());
+std::vector<size_t> idx(perf.size());
std::iota(idx.begin(), idx.end(), 0);
std::sort(idx.begin(), idx.end(), [&perf](size_t i1, size_t i2) {return perf[i1] > perf[i2];});
bool valid_found = false;
@@ -133,10 +135,9 @@ void profiles::value_type::execute(runtime::execution_handler const & expr)
}
label = idx[std::distance(times.begin(),std::min_element(times.begin(), times.end()))];
if(modify_output)
-*out = *bkp;
-}
+*out = execution_handler(-(-*bkp), execution_options_type(queue_));
labels_.insert({x, label});
-//Executes
+}
if(templates_[label]->temporary_workspace(expr.x()) > MAX_TEMPORARY_WORKSPACE)
throw operation_not_supported_exception("Running this operation would require an overly large temporary.");
templates_[label]->enqueue(queue_, program, tools::to_string(label), expr);
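The backup logic in this hunk no longer copies through array's copy constructor: it allocates a fresh buffer with the output's shape and dtype, then evaluates -(-*out) into it (and -(-*bkp) on restore). In an expression-template engine, double negation is a cheap way to force materialization: the values are unchanged, but the assignment now goes through the normal expression-evaluation path and can be pinned to a specific queue. A minimal sketch of the idiom, with hypothetical vec/expr types far simpler than ISAAC's actual engine:

    #include <cstddef>
    #include <initializer_list>
    #include <iostream>
    #include <vector>

    struct vec;
    // Hypothetical one-node expression: sign * x, built by unary minus.
    struct expr {
      const vec& x;
      int sign;
      expr operator-() const { return {x, -sign}; }
    };

    struct vec {
      std::vector<double> data;
      explicit vec(std::size_t n) : data(n) {}
      vec(std::initializer_list<double> v) : data(v) {}
      expr operator-() const { return {*this, -1}; }
      // Assigning an expression runs the evaluator; this is the path that a
      // real engine can route through a chosen device queue.
      vec& operator=(expr e){
        data.resize(e.x.data.size());
        for(std::size_t i = 0; i < data.size(); ++i) data[i] = e.sign * e.x.data[i];
        return *this;
      }
    };

    int main(){
      vec x{1, 2, 3};
      vec bkp(3);
      bkp = -(-x);  // two negations cancel: a deep copy made by the evaluator
      std::cout << bkp.data[2] << "\n";  // prints 3
    }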

View File

@@ -309,6 +309,7 @@ void export_core()
ADD_ARRAY_OPERATOR(==)
ADD_ARRAY_OPERATOR(!=)
.add_property("context", bp::make_function(&sc::expression_tree::context, bp::return_internal_reference<>()))
.add_property("dtype", &sc::expression_tree::dtype)
.def(bp::self_ns::abs(bp::self))
// .def(bp::self_ns::pow(bp::self))
;

View File

@@ -261,7 +261,7 @@ bool diff(isaac::array const & x, VecType const & y, typename VecType::value_typ
{START1, START1 + STRIDE1*SUBM, STRIDE1}));\
template<typename test_fun_t>
-int run_test(test_fun_t const & testf, test_fun_t const & /*testd*/)
+int run_test(test_fun_t const & testf, test_fun_t const & testd)
{
int nfail = 0;
int npass = 0;
@@ -275,8 +275,8 @@ int run_test(test_fun_t const & testf, test_fun_t const & /*testd*/)
std::cout << "Device: " << device.name() << " on " << device.platform().name() << " " << device.platform().version() << std::endl;
std::cout << "---" << std::endl;
testf(*context, nfail, npass);
-// if(device.fp64_support())
-// testd(*context, nfail, npass);
+if(device.fp64_support())
+testd(*context, nfail, npass);
std::cout << "---" << std::endl;
}
if(nfail>0)
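Gating the double-precision tests on device.fp64_support() matters because fp64 is optional in OpenCL: a device advertises it through the cl_khr_fp64 extension, and double kernels simply fail to build on devices without it. For reference, a standalone check against the raw OpenCL API could look like this (a sketch, assuming an OpenCL SDK; ISAAC's own wrapper hides this behind fp64_support()):

    #include <CL/cl.h>
    #include <string>
    #include <vector>

    // Returns true if `dev` advertises the cl_khr_fp64 extension.
    bool fp64_support(cl_device_id dev)
    {
      size_t n = 0;
      clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, 0, nullptr, &n);
      std::vector<char> buf(n);
      clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, n, buf.data(), nullptr);
      return std::string(buf.begin(), buf.end()).find("cl_khr_fp64") != std::string::npos;
    }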

View File

@@ -23,6 +23,8 @@ from math import ceil, exp, log, sqrt
import time
profile_execution_failure = (sc.OperationNotSupported, sc.OclLaunchOutOfResources, sc.CudaLaunchOutOfResources, sc.MemObjectAllocationFailure, sc.InvalidWorkGroupSize, sc.OutOfHostMemory, sc.InvalidValue)
+dtype=sc.float32
def sanitize(string, keep_chars = ['_']):
string = string.replace(' ', '_').replace('-', '_').lower()
string = "".join(c for c in string if c.isalnum() or c in keep_chars).rstrip()
@@ -43,7 +45,7 @@ def expspace(a,b,N,r=128):
def benchmark(template, tree, operation=sc.templates.gemm_nn):
queue = tree.context.queues[0]
-queue.profiles[template, sc.float32] = sc.profile(template, sc.float32, queue)
+queue.profiles[template, dtype] = sc.profile(template, dtype, queue)
times = []
total = 0
#Warm-up
@@ -66,33 +68,33 @@ def benchmark(template, tree, operation=sc.templates.gemm_nn):
def tree_of(template, sizes, context):
if issubclass(template, sc.templates.elementwise_1d):
N, = sizes
-x = sc.empty(N, dtype=sc.float32, context=context)
-y = sc.empty(N, dtype=sc.float32, context=context)
+x = sc.empty(N, dtype=dtype, context=context)
+y = sc.empty(N, dtype=dtype, context=context)
return sc.assign(y, x + y), (x, y)
elif issubclass(template, sc.templates.reduce_1d):
N, = sizes
-x = sc.empty(N, context=context)
-y = sc.empty(N, context=context)
+x = sc.empty(N, dtype=dtype, context=context)
+y = sc.empty(N, dtype=dtype, context=context)
return sc.dot(x, y), (x, y)
elif issubclass(template, sc.templates.elementwise_2d):
M, N = sizes
-A = sc.empty((M,N), context=context)
-B = sc.empty((M,N), context=context)
+A = sc.empty((M,N), dtype=dtype, context=context)
+B = sc.empty((M,N), dtype=dtype, context=context)
return A + B, (A, B)
elif issubclass(template, sc.templates.reduce_2d):
T = template is sc.templates.reduce_2d_cols
M, N = sizes[::-1] if T else sizes
-A = sc.empty((M,N), context=context)
-x = sc.empty(N, context=context)
-y = sc.empty(M, context=context)
+A = sc.empty((M,N), dtype=dtype, context=context)
+x = sc.empty(N, dtype=dtype, context=context)
+y = sc.empty(M, dtype=dtype, context=context)
return sc.assign(x, sc.dot(A.T, y)) if T else sc.assign(y, sc.dot(A, x)), (A, x, y)
elif issubclass(template, sc.templates.gemm):
AT = template is sc.templates.gemm_tn or template is sc.templates.gemm_tt
BT = template is sc.templates.gemm_nt or template is sc.templates.gemm_tt
M, N, K = sizes
-C = sc.empty((M,N), context=context)
-A = sc.empty((K, M) if AT else (M, K), context=context)
-B = sc.empty((N, K) if BT else (K, N), context=context)
+C = sc.empty((M,N), dtype=dtype, context=context)
+A = sc.empty((K, M) if AT else (M, K), dtype=dtype, context=context)
+B = sc.empty((N, K) if BT else (K, N), dtype=dtype, context=context)
AA = A.T if AT else A
BB = B.T if BT else B
return sc.assign(C, sc.dot(AA, BB)), (A, B, C)

View File

@@ -46,10 +46,11 @@ def pow2range(a, b):
class Tuner:
-def __init__(self, logger, device, operation, json_path, progress_bar):
+def __init__(self, logger, device, operation, dtype, json_path, progress_bar):
self.logger = logger
self.device = device
self.operation = operation
+self.dtype = dtype
self.json_path = json_path
self.progress_bar = progress_bar
@@ -57,7 +58,7 @@ class Tuner:
def run(self, level = 'intermediate'):
assert level in ['simple', 'intermediate', 'full']
+tools.dtype = self.dtype
device = self.device
operation = self.operation
context = sc.driver.context(device)
@@ -65,6 +66,7 @@ class Tuner:
if self.logger:
self.logger.info("----------------")
self.logger.info(operation.__name__.replace('_','-').upper())
+self.logger.info(tools.dtype.__name__.upper())
self.logger.info("----------------")
#BLAS1 training sizes
@@ -116,7 +118,7 @@ class Tuner:
profiles, X, Y = [], [], []
#Restore progress
-savepath = os.path.join('save', operation.__name__)
+savepath = os.path.join('save', tools.dtype.__name__, operation.__name__)
if not os.path.exists(savepath):
os.makedirs(savepath)
try:
@@ -202,8 +204,8 @@ class Tuner:
operation_name = operation.__name__
if operation_name not in json_data:
json_data[operation_name] = {}
-json_data[operation_name]['float32'] = {}
-D = json_data[operation_name]['float32']
+json_data[operation_name][tools.dtype.__name__] = {}
+D = json_data[operation_name][tools.dtype.__name__]
if len(profiles) > 1:
clf, nrmse = model.train(X, Y, profiles)
D['predictor'] = [{'children_left': e.tree_.children_left.tolist(),

View File

@@ -30,6 +30,8 @@ def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default=0, type=int, help='Device to tune for')
parser.add_argument("-j", "--json", default='', type=str)
+parser.add_argument('--float32', action='store_true', help='Tune 32-bits FLOAT')
+parser.add_argument('--float64', action='store_true', help='Tune 64-bits FLOAT')
parser.add_argument('--elementwise_1d', action='store_true', help='Tune ELEMENTWISE [1D]')
parser.add_argument('--elementwise_2d', action='store_true', help='Tune ELEMENTWISE [2D]')
parser.add_argument('--reduce_1d', action='store_true', help='Tune REDUCE [1D]')
@@ -54,7 +56,12 @@ def parse_arguments():
operations = ['elementwise_1d', 'reduce_1d', 'elementwise_2d', 'reduce_2d_rows', 'reduce_2d_cols', 'gemm_nn', 'gemm_tn', 'gemm_nt', 'gemm_tt']
operations = [getattr(sc.templates,op) for op in operations if getattr(args, op)]
-return (device, operations, args.json)
+#Dtypes
+dtypes = []
+if args.float32: dtypes+=[sc.float32]
+if args.float64: dtypes+=[sc.float64]
+return (device, operations, dtypes, args.json)
class ProgressBar:
@@ -91,8 +98,9 @@ if __name__ == "__main__":
logger.setLevel(logging.INFO)
sc.driver.default.queue_properties = sc.driver.PROFILING_ENABLE
-device, operations, json = parse_arguments()
+device, operations, dtypes, json = parse_arguments()
for operation in operations:
-tuner = Tuner(logger, device, operation, json, ProgressBar(30, metric_name_of(operation)))
+for dtype in dtypes:
+tuner = Tuner(logger, device, operation, dtype, json, ProgressBar(30, metric_name_of(operation)))
tuner.run(level='intermediate')