CUDA: various improvements

Philippe Tillet
2015-08-21 13:06:20 -04:00
parent 33dac6b05a
commit 10524ebdee
25 changed files with 170 additions and 130 deletions

View File

@@ -39,12 +39,14 @@ string(REPLACE ";" " " BLAS_DEF_STR "${BLAS_DEF}")
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
foreach(PROG blas)
if(CUDA_FOUND)
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} " ${BLAS_DEF_STR} -std=c++11 ${BACKEND_DEFINES}")
set(OLD_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "")
set(CUPROG ${CMAKE_CURRENT_BINARY_DIR}/${PROG}.cu)
file(COPY ${PROG}.cpp DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
file(RENAME ${CMAKE_CURRENT_BINARY_DIR}/${PROG}.cpp ${CUPROG})
cuda_add_executable(${PROG}-bench ${CUPROG})
cuda_add_cublas_to_target(${PROG}-bench)
cuda_add_executable(${PROG}-bench ${CUPROG} OPTIONS "-std=c++11 ${BLAS_DEF_STR} ${BACKEND_DEFINES}")
cuda_add_cublas_to_target(${PROG}-bench)
set(CMAKE_CXX_FLAGS "${OLD_CXX_FLAGS}")
else()
add_executable(${PROG}-bench ${PROG}.cpp)
set_target_properties(${PROG}-bench PROPERTIES COMPILE_FLAGS "${BLAS_DEF_STR}")

View File

@@ -94,7 +94,7 @@ void bench(sc::numeric_type dtype, std::string operation)
cudaEventCreate(&stop);\
OP;\
cudaThreadSynchronize();\
while(total_time*1e-3 < 1e-3){\
while(total_time*1e-3 < 1e-2){\
flush = sc::zeros(1e6, 1, dtype);\
cudaEventRecord(start,0);\
OP;\
@@ -191,7 +191,6 @@ void bench(sc::numeric_type dtype, std::string operation)
#endif
#ifdef BENCH_CUBLAS
T *cux, *cuy;
T result;
cudaMalloc((void**) &cux, N * sizeof(T));
cudaMalloc((void**) &cuy, N * sizeof(T));
BENCHMARK_CUDA(cublasSdot(N, cux, 1, cuy, 1), 2*N*dtsize/t)
@@ -210,6 +209,7 @@ void bench(sc::numeric_type dtype, std::string operation)
//AlexNet
MNs.push_back(std::make_tuple('N',1000,256));
MNs.push_back(std::make_tuple('N',4096,256));
MNs.push_back(std::make_tuple('T',169,256));
MNs.push_back(std::make_tuple('T',169,384));
MNs.push_back(std::make_tuple('T',729,256));
@@ -261,39 +261,43 @@ void bench(sc::numeric_type dtype, std::string operation)
if(operation.substr(0,4)=="gemm")
{
std::vector<std::tuple<std::string, char, char, int_t, int_t, int_t> > MNKs;
MNKs.push_back(std::make_tuple("Square [512]",'N','T',512,512,512));
MNKs.push_back(std::make_tuple("Square [1536]",'N','T',1536,1536,1536));
//AlexNet (Forward)
MNKs.push_back(std::make_tuple("F-Conv1",'N','N',3025,96,363));
MNKs.push_back(std::make_tuple("F-Conv2",'N','N',729,128,1200));
MNKs.push_back(std::make_tuple("F-Conv3",'N','N',169,384,2304));
MNKs.push_back(std::make_tuple("F-Conv4",'N','N',169,192,1728));
MNKs.push_back(std::make_tuple("F-Conv5",'N','N',169,128,1728));
//LeNet (Forward)
MNKs.push_back(std::make_tuple("F-Conv1",'N','N',576,20,25));
MNKs.push_back(std::make_tuple("F-Conv2",'N','N',64,50,500));
//Square
MNKs.push_back(std::make_tuple("Square [N=896]",'N','T',896,896,896));
MNKs.push_back(std::make_tuple("Square [N=2560]",'N','T',2560,2560,2560));
//AlexNet (Backward)
MNKs.push_back(std::make_tuple("B-Conv5",'T','N',1728,128,169));
MNKs.push_back(std::make_tuple("B-Conv4",'T','N',1728,192,169));
MNKs.push_back(std::make_tuple("B-Conv3",'T','N',2304,384,169));
MNKs.push_back(std::make_tuple("B-Conv2",'T','N',1200,128,729));
MNKs.push_back(std::make_tuple("B-Conv1",'T','N',363,96,3025));
//LeNet (Backward)
MNKs.push_back(std::make_tuple("B-Conv2",'T','N',500,50,64));
MNKs.push_back(std::make_tuple("B-Conv1",'T','N',25,20,576));
//Convolution
MNKs.push_back(std::make_tuple("Convolution [AlexNet-1]",'N','N',3025,96,363));
MNKs.push_back(std::make_tuple("Convolution [AlexNet-2]",'N','N',729,128,1200));
MNKs.push_back(std::make_tuple("Convolution [AlexNet-3]",'N','N',169,384,2304));
MNKs.push_back(std::make_tuple("Convolution [AlexNet-4]",'N','N',169,192,1728));
MNKs.push_back(std::make_tuple("Convolution [AlexNet-5]",'N','N',169,128,1728));
// MNKs.push_back(std::make_tuple("Convolution [LeNet-1],'N','N',576,20,25));
// MNKs.push_back(std::make_tuple("Convolution [LeNet-2]",'N','N',64,50,500));
//Convolution Gradient-1
MNKs.push_back(std::make_tuple("Convolution Gradient-1 [AlexNet-5]",'T','N',1728,128,169));
MNKs.push_back(std::make_tuple("Convolution Gradient-1 [AlexNet-4]",'T','N',1728,192,169));
MNKs.push_back(std::make_tuple("Convolution Gradient-1 [AlexNet-3]",'T','N',2304,384,169));
MNKs.push_back(std::make_tuple("Convolution Gradient-1 [AlexNet-2]",'T','N',1200,128,729));
MNKs.push_back(std::make_tuple("Convolution Gradient-1 [AlexNet-1]",'T','N',363,96,3025));
// MNKs.push_back(std::make_tuple("Conv. Gradient-1 [LeNet-2]",'T','N',500,50,64));
// MNKs.push_back(std::make_tuple("Conv. Gradient-1 [LeNet-1]",'T','N',25,20,576));
MNKs.push_back(std::make_tuple("B-Conv5 [bottom]",'N','T',169,1728,128));
MNKs.push_back(std::make_tuple("B-Conv4 [bottom]",'N','T',169,1728,192));
MNKs.push_back(std::make_tuple("B-Conv3 [bottom]",'N','T',169,2304,384));
MNKs.push_back(std::make_tuple("B-Conv2 [bottom]",'N','T',729,1200,128));
//LeNet (Backward)
MNKs.push_back(std::make_tuple("B-Conv2 [bottom]",'N','T',64,500,50));
//Convolution Gradient-2
MNKs.push_back(std::make_tuple("Convolution Gradient-2 [AlexNet-5]",'N','T',169,1728,128));
MNKs.push_back(std::make_tuple("Convolution Gradient-2 [AlexNet-4]",'N','T',169,1728,192));
MNKs.push_back(std::make_tuple("Convolution Gradient-2 [AlexNet-3]",'N','T',169,2304,384));
MNKs.push_back(std::make_tuple("Convolution Gradient-2 [AlexNet-2]",'N','T',729,1200,128));
// MNKs.push_back(std::make_tuple("Conv. Gradient-2 [LeNet-2]",'N','T',64,500,50));
//Covariance (e.g., ICA)
MNKs.push_back(std::make_tuple("ICA [32 channels]",'N','N',32,32,32000));
MNKs.push_back(std::make_tuple("ICA [256 channels]",'N','N',256,256,32000));
//Covariance (e.g., ICA, 10 minutes at 1kHz)
MNKs.push_back(std::make_tuple("ICA [32 channels]",'N','T',32,32,600000));
MNKs.push_back(std::make_tuple("ICA [256 channels]",'N','T',256,256,600000));
//Bi-diagonalization
MNKs.push_back(std::make_tuple("Bidiagonalization [Iteration 1]",'N','T',4096,4096,32));
MNKs.push_back(std::make_tuple("Bidiagonalization [Iteration 10]",'N','T',3456,3456,32));
MNKs.push_back(std::make_tuple("Bidiagonalization [Iteration 50]",'N','T',896,896,32));
/*---------*/
/*--BLAS3--*/
@@ -317,7 +321,7 @@ void bench(sc::numeric_type dtype, std::string operation)
#ifdef HAS_A_BLAS
int_t lda = A.ld(), ldb = B.ld(), ldc = C.ld();
#endif
BENCHMARK_ISAAC(C = sc::control(AT?(BT?dot(A.T(),B.T()):dot(A.T(),B)):(BT?dot(A,B.T()):dot(A,B)), sc::execution_options_type(0, &events)), (double)2*M*N*K/t);
BENCHMARK_ISAAC(C = sc::control(AT?(BT?dot(A.T(),B.T()):dot(A.T(),B)):(BT?dot(A,B.T()):dot(A,B)), sc::execution_options_type(0, &events), sc::dispatcher_options_type(false)), (double)2*M*N*K/t);
/* clblas */
#ifdef BENCH_CLBLAS
BENCHMARK_CLBLAS(clblasSgemm(clblasColumnMajor, AT?clblasTrans:clblasNoTrans, BT?clblasTrans:clblasNoTrans, M, N, K, 1, CL_HANDLE(A.data()), 0, lda, CL_HANDLE(B.data()), 0, ldb,
@@ -336,7 +340,7 @@ void bench(sc::numeric_type dtype, std::string operation)
cudaMalloc((void**) &cuA, M * K * sizeof(T));
cudaMalloc((void**) &cuB, K * N * sizeof(T));
cudaMalloc((void**) &cuC, M * N * sizeof(T));
BENCHMARK_CUDA(cublasSgemm('n', 't', M, N, K, 1, cuA, lda, cuB, ldb, 1, cuC, ldc), (double)2*M*N*K/t)
BENCHMARK_CUDA(cublasSgemm(AT?'t':'n', BT?'t':'n', M, N, K, 1, cuA, lda, cuB, ldb, 1, cuC, ldc), (double)2*M*N*K/t)
cudaFree(cuA);
cudaFree(cuB);
cudaFree(cuC);
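A minimal sketch of the event-timing pattern behind BENCHMARK_CUDA, for context: the first hunk in this file raises the sampling window from 1 ms to 10 ms of accumulated kernel time. The helper name and loop body are illustrative, not the macro's exact expansion:

#include <cuda_runtime.h>
// Illustrative restatement of the BENCHMARK_CUDA loop (not the macro itself)
template<class OP>
float accumulate_ms(OP op)
{
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  float total_time = 0;                // milliseconds
  while(total_time*1e-3 < 1e-2)        // keep sampling until ~10 ms total
  {
    cudaEventRecord(start, 0);
    op();                              // benchmarked call, e.g. cublasSgemm(...)
    cudaEventRecord(stop, 0);
    cudaEventSynchronize(stop);
    float ms = 0;
    cudaEventElapsedTime(&ms, start, stop);
    total_time += ms;
  }
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return total_time;
}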

View File

@@ -143,7 +143,7 @@ def main():
libraries=libraries)]
#External
extensions += [Extension('autotuning.external.sklearn._tree',
extensions += [Extension('external.sklearn._tree',
['external/sklearn/_tree.c'],
include_dirs = [numpy_include])]
@@ -155,7 +155,7 @@ def main():
author='Philippe Tillet',
author_email='ptillet@g.harvard.edu',
license='MPL 2.0',
packages=['isaac','isaac.autotuning', 'isaac.autotuning.external', 'isaac.autotuning.external.deap', 'isaac.autotuning.external.deap.tools', 'isaac.autotuning.external.sklearn'],
packages=['isaac','isaac.external', 'isaac.external.sklearn'],
ext_package="isaac",
ext_modules=extensions,
cmdclass={'build_py': build_py, 'build_ext': build_ext_subclass},

View File

@@ -18,10 +18,10 @@
#define RESTORE_MSVC_WARNING_C4275 __pragma(warning(disable: 4275))
#else
#define DISABLE_MSVC_WARNING_C4251
#define RESTORE_MSVC_WARNING_C4251
#define DISABLE_MSVC_WARNING_C4275
#define RESTORE_MSVC_WARNING_C4275
#define DISABLE_MSVC_WARNING_C4251
#define RESTORE_MSVC_WARNING_C4251
#define DISABLE_MSVC_WARNING_C4275
#define RESTORE_MSVC_WARNING_C4275
#endif
#endif

View File

@@ -19,7 +19,7 @@ namespace isaac
namespace driver
{
enum ISAACAPI backend_type
enum backend_type
{
OPENCL
#ifdef ISAAC_WITH_CUDA

View File

@@ -20,7 +20,7 @@ private:
friend class CommandQueue;
public:
enum ISAACAPI Type
enum Type
{
GPU = CL_DEVICE_TYPE_GPU,
CPU = CL_DEVICE_TYPE_CPU,

View File

@@ -15,10 +15,6 @@ namespace driver
class ISAACAPI Event
{
friend class CommandQueue;
private:
#ifdef ISAAC_WITH_CUDA
typedef std::pair<CUevent, CUevent> cu_event_t;
#endif
public:
Event(cl_event const & event, bool take_ownership = true);
Event(backend_type backend);

View File

@@ -12,6 +12,15 @@ namespace isaac
namespace driver
{
#ifdef ISAAC_WITH_CUDA
struct cu_event_t{
operator bool() const { return first && second; }
CUevent first;
CUevent second;
};
#endif
#ifdef ISAAC_WITH_CUDA
#define HANDLE_TYPE(CLTYPE, CUTYPE) Handle<CLTYPE, CUTYPE>
#else
@@ -30,7 +39,7 @@ private:
static void _delete(CUevent x);
static void _delete(CUfunction);
static void _delete(CUmodule x);
static void _delete(std::pair<CUevent, CUevent> x);
static void _delete(cu_event_t x);
#endif
static void release(cl_context x);
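The new struct replaces std::pair<CUevent, CUevent> chiefly to gain an operator bool, which the tightened Handle destructor further down uses to skip deleting null handles. A hedged usage sketch, assuming <cuda.h> and <cassert>:

cu_event_t e{};                              // aggregate-zeroed: both events null
assert(!e);                                  // invalid, so ~Handle() would skip _delete
cuEventCreate(&e.first, CU_EVENT_DEFAULT);
cuEventCreate(&e.second, CU_EVENT_DEFAULT);
assert(e);                                   // valid: _delete(e) destroys both events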

View File

@@ -21,7 +21,7 @@ public:
virtual ~symbolic_binder();
virtual bool bind(driver::Buffer const &) = 0;
virtual unsigned int get(driver::Buffer const &) = 0;
virtual unsigned int get();
unsigned int get();
protected:
unsigned int current_arg_;
std::map<driver::Buffer,unsigned int> memory;

View File

@@ -70,6 +70,7 @@ struct CastPrefix: public keyword{ CastPrefix(driver::backend_type backend, std:
struct InitPrefix: public keyword{ InitPrefix(driver::backend_type backend, std::string const & datatype): keyword(backend, "", "make_" + datatype){} };
struct Infinity: public keyword{ Infinity(driver::backend_type backend, std::string const & datatype): keyword(backend, "INFINITY", "infinity<" + datatype + ">()"){} };
struct Select: public keyword{ Select(driver::backend_type backend, std::string cond, std::string if_value, std::string else_value): keyword(backend, "select(" + else_value + "," + if_value + "," + cond + ")", "(" + cond + ")?" + if_value + ":" + else_value) {} };
#undef ADD_KEYWORD
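Per the constructor just above, one Select call renders as OpenCL's select() built-in or as a plain ternary depending on backend. A sketch of the two expansions (stream and backend are illustrative locals):

// Select(backend, "i < M", "A[i]", "0") prints as:
//   OpenCL : select(0,A[i],i < M)           // select(else, if, cond)
//   CUDA   : (i < M)?A[i]:0
stream << Select(backend, "i < M", "A[i]", "0") << ";" << std::endl;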

View File

@@ -1,9 +1,9 @@
#ifndef ISAAC_VALUE_SCALAR_H
#define ISAAC_VALUE_SCALAR_H
#include <inttypes.h>
#include "isaac/defines.h"
#include "isaac/types.h"
#include <stdint.h>
namespace isaac
{

View File

@@ -130,12 +130,12 @@ void backend::platforms(std::vector<Platform> & platforms)
#ifdef ISAAC_WITH_CUDA
platforms.push_back(Platform(CUDA));
#endif
cl_uint nplatforms;
ocl::check(clGetPlatformIDs(0, NULL, &nplatforms));
std::vector<cl_platform_id> clplatforms(nplatforms);
ocl::check(clGetPlatformIDs(nplatforms, clplatforms.data(), NULL));
for(cl_platform_id p: clplatforms)
platforms.push_back(Platform(p));
// cl_uint nplatforms;
// ocl::check(clGetPlatformIDs(0, NULL, &nplatforms));
// std::vector<cl_platform_id> clplatforms(nplatforms);
// ocl::check(clGetPlatformIDs(nplatforms, clplatforms.data(), NULL));
// for(cl_platform_id p: clplatforms)
// platforms.push_back(Platform(p));
}
void backend::synchronize(Context const & context)

View File

@@ -44,7 +44,7 @@ long Event::elapsed_time() const
}
}
HANDLE_TYPE(cl_event, Event::cu_event_t) & Event::handle()
HANDLE_TYPE(cl_event, cu_event_t) & Event::handle()
{ return h_; }
}

View File

@@ -19,7 +19,7 @@ template<class CLType, class CUType>
void Handle<CLType, CUType>::_delete(CUstream x) { cuStreamDestroy(x); }
template<class CLType, class CUType>
void Handle<CLType, CUType>::_delete(CUdevice) { }
void Handle<CLType, CUType>::_delete(CUdevice) { std::cout << "CUdevice" << std::endl;}
template<class CLType, class CUType>
void Handle<CLType, CUType>::_delete(CUevent x) { cuEventDestroy(x); }
@@ -31,7 +31,7 @@ template<class CLType, class CUType>
void Handle<CLType, CUType>::_delete(CUmodule x) { cuModuleUnload(x); }
template<class CLType, class CUType>
void Handle<CLType, CUType>::_delete(std::pair<CUevent, CUevent> x) { _delete(x.first); _delete(x.second); }
void Handle<CLType, CUType>::_delete(cu_event_t x) { _delete(x.first); _delete(x.second); }
#endif
@@ -100,8 +100,9 @@ template<class CLType, class CUType>
Handle<CLType, CUType>::~Handle()
{
#ifdef ISAAC_WITH_CUDA
if(has_ownership_ && cu_.unique())
if(has_ownership_ && cu_ && cu_.unique() && *cu_){
_delete(*cu_);
}
#endif
if(has_ownership_ && cl_ && cl_.unique() && *cl_)
release(*cl_);
@@ -132,7 +133,7 @@ template class Handle<cl_mem, CUdeviceptr>;
template class Handle<cl_command_queue, CUstream>;
template class Handle<cl_context, CUcontext>;
template class Handle<cl_device_id, CUdevice>;
template class Handle<cl_event, std::pair<CUevent, CUevent> >;
template class Handle<cl_event, cu_event_t>;
template class Handle<cl_kernel, CUfunction>;
template class Handle<cl_program, CUmodule>;
#else
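Net effect of the destructor change in this file, summarized as a sketch of the semantics rather than new code:

// before: if(has_ownership_ && cu_.unique())
//           -> a null or never-created handle could still reach _delete
// after : if(has_ownership_ && cu_ && cu_.unique() && *cu_)
//           -> only the last owner of a non-null, valid handle releases it;
//              for events, cu_event_t::operator bool supplies the *cu_ test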

View File

@@ -10,6 +10,8 @@ namespace driver
NDRange::NDRange(size_t size0)
{
sizes_[0] = size0;
sizes_[1] = 1;
sizes_[2] = 1;
dimension_ = 1;
}
@@ -17,6 +19,7 @@ NDRange::NDRange(size_t size0, size_t size1)
{
sizes_[0] = size0;
sizes_[1] = size1;
sizes_[2] = 1;
dimension_ = 2;
}
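Defaulting the unused dimensions to 1 matters because a CUDA driver launch always consumes three grid dimensions. A sketch of what a 1-D NDRange must ultimately feed to cuLaunchKernel (f, stream and args are placeholders):

NDRange global(4096);                  // sizes_ is now {4096, 1, 1}, not {4096, garbage, garbage}
cuLaunchKernel(f, 4096, 1, 1,          // gridDimX/Y/Z
                  64,   1, 1,          // blockDimX/Y/Z
                  0, stream, args, NULL);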

View File

@@ -12,7 +12,11 @@ namespace driver
{
#ifdef ISAAC_WITH_CUDA
Platform::Platform(backend_type backend): backend_(backend){}
Platform::Platform(backend_type backend): backend_(backend)
{
if(backend==CUDA)
cuInit(0);
}
#endif
Platform::Platform(cl_platform_id const & platform) : backend_(OPENCL)
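The driver API rejects every other call with CUDA_ERROR_NOT_INITIALIZED until cuInit has run, so doing it in the CUDA Platform constructor guarantees initialization as early as possible. A hedged sketch:

Platform cuda(CUDA);                   // now runs cuInit(0) internally
CUdevice dev;
cuDeviceGet(&dev, 0);                  // safe only after cuInit(0)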

View File

@@ -56,8 +56,11 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
return 0;
}
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
int gemm::is_invalid_impl(driver::Device const & device, expressions_tuple const &) const
{
if(device.vendor()==driver::Device::Vendor::NVIDIA && p_.simd_width > 1)
return TEMPLATE_INVALID_SIMD_WIDTH;
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
@@ -170,8 +173,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << Local(backend) << " " << sdtype << " lB[" << p_.kL*p_.nL << "];" << std::endl;
unsigned int npA = p_.mL/(A_trans_=='N'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl;
stream << "__global " << sdtype << "* Bi[" << npB << "];" << std::endl;
stream << Global(backend) << " " << sdtype << "* Ai[" << npA << "];" << std::endl;
stream << Global(backend) << " " << sdtype << "* Bi[" << npB << "];" << std::endl;
stream << std::endl;
stream << "//identifiers" << std::endl;
@@ -179,7 +182,11 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "int idt;" << std::endl;
if(has_depth)
stream << "int gidz, div, offz;" << std::endl;
stream << "int4 ids = (int4)(" << GroupIdx0(backend) << "," << GroupIdx1(backend) << "," << LocalIdx0(backend) << "," << LocalIdx1(backend) << ");" << std::endl;
stream << "uint4 ids;" << std::endl;
stream << "ids.x = " << GroupIdx0(backend) << ";" << std::endl;
stream << "ids.y = " << GroupIdx1(backend) << ";" << std::endl;
stream << "ids.z = " << LocalIdx0(backend) << ";" << std::endl;
stream << "ids.w = " << LocalIdx1(backend) << ";" << std::endl;
stream << std::endl;
stream << "//offsets" << std::endl;
@@ -266,15 +273,15 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
for(unsigned int i = 0 ; i < npA ; i++ )
if (A_trans_=='N')
stream << "Ai[" << i << "] += select(0, (int)((idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << ASTRIDE1 << "), " << i*p_.local_fetch_0*p_.simd_width << " < M);" << std::endl;
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.local_fetch_0*p_.simd_width) + " < M", "(int)((idT.x + " + to_string(i*p_.local_fetch_0*p_.simd_width) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
else
stream << "Ai[" << i << "] += select(0, (int)((idT.y + " << i*p_.local_fetch_1 << ")*lda), " << i*p_.local_fetch_1 << " < M);" << std::endl;
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.local_fetch_1) + " < M", "(int)((idT.y + " + to_string(i*p_.local_fetch_1) + ")*lda)", "0") << ";" << std::endl;
for(unsigned int i = 0 ; i < npB ; i++ )
if (B_trans_=='T')
stream << "Bi[" << i << "] += select(0, (int)((idT.x + " << i*p_.local_fetch_0*p_.simd_width << ")" << BSTRIDE1 << "), " << i*p_.local_fetch_0*p_.simd_width << " < N);" << std::endl;
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.local_fetch_0*p_.simd_width) + " < N", "(int)((idT.x + " + to_string(i*p_.local_fetch_0*p_.simd_width) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
else
stream << "Bi[" << i << "] += select(0, (int)((idT.y + " << i*p_.local_fetch_1 << ")*ldb), " << i*p_.local_fetch_1 << " < N);" << std::endl;
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.local_fetch_1) + " < N", "(int)((idT.y + " + to_string(i*p_.local_fetch_1) + ")*ldb)", "0") << ";" << std::endl;
stream << std::endl;
stream << "//Outer loop" << std::endl;
@@ -504,14 +511,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream << "N -= ids.y;" << std::endl;
stream << "N -= ids.w*" << p_.simd_width << ";" << std::endl;
stream << "int ibm[" << p_.mS << "];" << std::endl;
for(unsigned int m=0; m < p_.mS; ++m)
{
string Ci = to_string((m/p_.simd_width)*(p_.local_size_0*p_.simd_width) + m%p_.simd_width);
stream << "ibm[" << m << "] = " << Ci << " < M;" << std::endl;
}
for(unsigned int n=0; n < p_.nS; ++n)
{
string Cj = to_string((n/p_.simd_width)*(p_.local_size_1*p_.simd_width) + n%p_.simd_width);
@@ -521,13 +520,15 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
for(unsigned int m=0; m < p_.mS; ++m)
{
string Ci = to_string((m/p_.simd_width)*(p_.local_size_0*p_.simd_width) + m%p_.simd_width);
stream << "if(ibm[" << m << "]) ";
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + select((" << sdtype << ")0, C[" << Ci << CSTRIDE1 << "], beta>0);" << std::endl;
stream << "if(" << Ci << "< M) ";
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl;
}
if((n+1)%p_.simd_width==0)
if((n+1)%p_.simd_width==0){
stream << "C += ldc*" << p_.local_size_1*p_.simd_width - p_.simd_width + 1 << ";" << std::endl;
else
}
else{
stream << "C += ldc;" << std::endl;
}
}
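The ids rewrite above drops the OpenCL-only (int4)(...) vector-literal syntax in favor of member-wise stores, which compile as both OpenCL C and CUDA C. Generated forms, assuming the usual keyword mapping (get_group_id(0) vs blockIdx.x, etc.):

// old, OpenCL-only:
//   int4 ids = (int4)(get_group_id(0), get_group_id(1),
//                     get_local_id(0), get_local_id(1));
// new, valid in both dialects:
//   uint4 ids;
//   ids.x = get_group_id(0);          // blockIdx.x under CUDA
//   ids.y = get_group_id(1);          // blockIdx.y
//   ids.z = get_local_id(0);          // threadIdx.x
//   ids.w = get_local_id(1);          // threadIdx.y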

View File

@@ -119,15 +119,16 @@ void profiles::value_type::execute(controller<expressions_tuple> const & expr)
else if(predictor_.get())
{
std::vector<float> predictions = predictor_->predict(x);
do{
// do{
label = std::distance(predictions.begin(),std::max_element(predictions.begin(), predictions.end()));
predictions[label] = 0;
}while(templates_[label]->temporary_workspace(expr.x()) > MAX_TEMPORARY_WORKSPACE);
// predictions[label] = 0;
// }while(templates_[label]->temporary_workspace(expr.x()) > MAX_TEMPORARY_WORKSPACE);
}
//Execution
if(templates_[label]->temporary_workspace(expr.x()) > MAX_TEMPORARY_WORKSPACE)
throw operation_not_supported_exception("Running this operation would require an overly large temporary.");
// std::cout << label << std::endl;
// if(templates_[label]->temporary_workspace(expr.x()) > MAX_TEMPORARY_WORKSPACE)
// throw operation_not_supported_exception("Running this operation would require an overly large temporary.");
return templates_[label]->enqueue(queue_, program, tools::to_string(label), *fallback_, expr);
}

View File

@@ -143,7 +143,7 @@ def main():
libraries=libraries)]
#External
extensions += [Extension('autotuning.external.sklearn._tree',
extensions += [Extension('external.sklearn._tree',
['external/sklearn/_tree.c'],
include_dirs = [numpy_include])]
@@ -155,7 +155,7 @@ def main():
author='Philippe Tillet',
author_email='ptillet@g.harvard.edu',
license='MPL 2.0',
packages=['isaac','isaac.autotuning', 'isaac.autotuning.external', 'isaac.autotuning.external.deap', 'isaac.autotuning.external.deap.tools', 'isaac.autotuning.external.sklearn'],
packages=['isaac','isaac.external', 'isaac.external.sklearn'],
ext_package="isaac",
ext_modules=extensions,
cmdclass={'build_py': build_py, 'build_ext': build_ext_subclass},

View File

@@ -81,27 +81,32 @@ private:
void export_exceptions()
{
wrap::exception<isaac::operation_not_supported_exception>("OperationNotSupported", bp::init<std::string>())
.def("__str__", &isaac::operation_not_supported_exception::what)
;
wrap::exception<isaac::driver::ocl::exception::out_of_resources>("LaunchOutOfResources")
.def("__str__", &isaac::driver::ocl::exception::out_of_resources::what)
;
#define BIND_EXCEPTION(CPPNAME, PYTHONNAME) \
wrap::exception<isaac::CPPNAME>(PYTHONNAME, bp::init<std::string>())\
.def("__str__", &isaac::CPPNAME::what)
wrap::exception<isaac::driver::ocl::exception::mem_object_allocation_failure>("MemObjectAllocationFailure")
.def("__str__", &isaac::driver::ocl::exception::mem_object_allocation_failure::what)
;
BIND_EXCEPTION(operation_not_supported_exception, "OperationNotSupported");
wrap::exception<isaac::driver::ocl::exception::out_of_host_memory>("OutOfHostMemory")
.def("__str__", &isaac::driver::ocl::exception::out_of_host_memory::what)
;
//OCL
wrap::exception<isaac::driver::ocl::exception::base>("OclException", bp::no_init);
#define BIND_OCL_EXCEPTION(CPPNAME, PYTHONNAME) \
wrap::exception<isaac::driver::ocl::exception::CPPNAME, bp::bases<isaac::driver::ocl::exception::base> >(PYTHONNAME)\
.def("__str__", &isaac::driver::ocl::exception::CPPNAME::what)
wrap::exception<isaac::driver::ocl::exception::invalid_work_group_size>("InvalidWorkGroupSize")
.def("__str__", &isaac::driver::ocl::exception::invalid_work_group_size::what)
;
wrap::exception<isaac::driver::ocl::exception::invalid_value>("InvalidValue")
.def("__str__", &isaac::driver::ocl::exception::invalid_value::what)
;
BIND_OCL_EXCEPTION(out_of_resources, "OclLaunchOutOfResources");
BIND_OCL_EXCEPTION(mem_object_allocation_failure, "MemObjectAllocationFailure");
BIND_OCL_EXCEPTION(out_of_host_memory, "OutOfHostMemory");
BIND_OCL_EXCEPTION(invalid_work_group_size, "InvalidWorkGroupSize");
BIND_OCL_EXCEPTION(invalid_value, "InvalidValue");
//CUDA
wrap::exception<isaac::driver::cuda::exception::base>("CudaException", bp::no_init);
#define BIND_CUDA_EXCEPTION(CPPNAME, PYTHONNAME) \
wrap::exception<isaac::driver::cuda::exception::CPPNAME, bp::bases<isaac::driver::cuda::exception::base> >(PYTHONNAME)\
.def("__str__", &isaac::driver::cuda::exception::CPPNAME::what)
BIND_CUDA_EXCEPTION(launch_out_of_resources, "CudaLaunchOutOfResources");
}
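For reference, a sketch of what one of the new binding macros expands to, following the definitions above:

// BIND_CUDA_EXCEPTION(launch_out_of_resources, "CudaLaunchOutOfResources") expands to:
wrap::exception<isaac::driver::cuda::exception::launch_out_of_resources,
                bp::bases<isaac::driver::cuda::exception::base> >("CudaLaunchOutOfResources")
    .def("__str__", &isaac::driver::cuda::exception::launch_out_of_resources::what)
;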

View File

@@ -3,6 +3,7 @@
#include "common.hpp"
#include "isaac/array.h"
#include "isaac/wrap/clBLAS.h"
#include "isaac/driver/common.h"
namespace sc = isaac;
typedef isaac::int_t int_t;
@@ -45,6 +46,7 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
std::cout << std::endl;\
}
if(queue.device().backend()==sc::driver::OPENCL){
#define PREFIX "[C]"
RUN_TEST_VECTOR_AXPY("AXPY", cz[i] = a*cx[i] + cz[i], BLAS<T>::F(clblasSaxpy, clblasDaxpy)(N, a, CHANDLE(x), x.start()[0], x.stride()[0],
CHANDLE(z), z.start()[0], z.stride()[0],
@@ -56,9 +58,9 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
RUN_TEST_VECTOR_AXPY("SCAL", cz[i] = a*cz[i], BLAS<T>::F(clblasSscal, clblasDscal)(N, a, CHANDLE(z), z.start()[0], z.stride()[0],
1, &clqueue, 0, NULL, NULL));
#undef PREFIX
}
#define PREFIX "[C++]"
RUN_TEST_VECTOR_AXPY("z = 0", cz[i] = 0, z = zeros(N, 1, dtype, context))
RUN_TEST_VECTOR_AXPY("z = x", cz[i] = cx[i], z = x)

View File

@@ -52,7 +52,7 @@ void test_impl(T epsilon, simple_matrix_base<T> & cC, simple_matrix_base<T> cons
else\
std::cout << std::endl;
if(interf==clBLAS)
if(C.context().backend()==sc::driver::OPENCL && interf==clBLAS)
{
cl_command_queue clqueue = queue.handle().cl();

View File

@@ -24,20 +24,23 @@ def nrmse(y_ground, y):
def train(X, Y, profiles):
X = np.array(X)
Y = np.array(Y)
M = X.shape[0]
#Remove unused profiles
unused = np.where(np.bincount(np.argmax(Y, 1))==0)[0]
profiles = [x for ix,x in enumerate(profiles) if ix not in unused]
Y = np.delete(Y, unused, axis=1)
#Shuffle
p = np.random.permutation(X.shape[0])
M = X.shape[0]
X = X[p,:]
Y = Y[p,:]
#Train the profile
cut = int(1.00*M)
CV = .1
XTr, YTr = X[:,:], Y[:,:]
XCv, YCv = X[:max(1,CV*M),:], Y[:max(1,CV*M),:]
cut = int(.7*M)
XTr, YTr = X[:cut,:], Y[:cut,:]
XCv, YCv = X[cut:,:], Y[cut:,:]
nrmses = {}
for N in range(1,min(M+1,20)):
for N in range(1,min(M+1,10)):
for depth in range(1,min(M+1,20)):
clf = RandomForestRegressor(N, max_depth=depth).fit(XTr, YTr)
t = np.argmax(clf.predict(XCv), axis = 1)

View File

@@ -13,12 +13,12 @@ from numpy import cumsum
import tools
fetch_types = [sc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_CONTIGUOUS,
sc.templates.fetching_policy_type.FETCH_FROM_GLOBAL_STRIDED,
fetch_types = [sc.templates.fetching_policy_type.FETCH_FROM_LOCAL,
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL,
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL,
sc.templates.fetching_policy_type.FETCH_FROM_LOCAL]
to_catch = (sc.OperationNotSupported, sc.LaunchOutOfResources, sc.MemObjectAllocationFailure, sc.InvalidWorkGroupSize, sc.OutOfHostMemory, sc.InvalidValue)
to_catch = (sc.OperationNotSupported, sc.OclLaunchOutOfResources, sc.CudaLaunchOutOfResources, sc.MemObjectAllocationFailure, sc.InvalidWorkGroupSize, sc.OutOfHostMemory, sc.InvalidValue)
def exhaustive(template, sizes, context):
tree, _ = tools.tree_of(template, sizes, context)

View File

@@ -1,7 +1,7 @@
import random, argparse, json, os
from math import log, isinf
from itertools import chain, product
from numpy import argsort, argmax
from numpy import argsort, argmax, where, delete, bincount
from operator import mul
import isaac as sc
from isaac.external.sklearn.forest import RandomForestRegressor
@@ -9,6 +9,8 @@ import optimize, tools, model
from json import encoder
import json
to_catch = (sc.OperationNotSupported, sc.OclLaunchOutOfResources, sc.CudaLaunchOutOfResources, sc.MemObjectAllocationFailure, sc.InvalidWorkGroupSize, sc.OutOfHostMemory, sc.InvalidValue)
encoder.FLOAT_REPR = lambda o: format(o, '.2f')
encoder.separators = (',',':')
@@ -30,14 +32,14 @@ def do_tuning(device, operation, json_path):
sizes[sc.templates.axpy] = [(x,) for x in tools.expspace(1e3, 1e8, 4)]
sizes[sc.templates.gemv_n] = product(pow2range(4,17), pow2range(4,17))
sizes[sc.templates.gemv_t] = sizes[sc.templates.gemv_n]
sizes[sc.templates.gemm_nn] = product(pow2range(5, 12), pow2range(5, 12), pow2range(5, 15))
sizes[sc.templates.gemm_nn] = product(pow2range(5, 12), pow2range(5, 12), pow2range(5, 17))
sizes[sc.templates.gemm_tn] = sizes[sc.templates.gemm_nn]
sizes[sc.templates.gemm_nt] = sizes[sc.templates.gemm_nn]
sizes[sc.templates.gemm_tt] = sizes[sc.templates.gemm_nn]
#Quick tuning - AlexNet sizes + Intuition
quick_tuning = True
quick_tuning = False
if quick_tuning:
sizes[sc.templates.ger] = [(1536,1536)]
@@ -48,12 +50,14 @@ def do_tuning(device, operation, json_path):
(729,256),
(3025,96)]
sizes[sc.templates.gemm_nn] = [(576,20,25),(64,50,500),(3025,96,363),
sizes[sc.templates.gemm_nn] = [(3025,96,363),
(729,128,1200),
(169,384,2304),
(169,192,1728),
(169,128,1728)]
sizes[sc.templates.gemm_nt] = [(169,1728,128),
sizes[sc.templates.gemm_nt] = [(1536,1536,1536),
(169,1728,128),
(169,1728,192),
(169,2304,384),
(729,1200,128)]
@@ -121,7 +125,7 @@ def do_tuning(device, operation, json_path):
for b in best:
try:
perf += [performance(x, tools.benchmark(operation, profiles[b], tree))]
except (sc.OperationNotSupported, sc.LaunchOutOfResources, sc.MemObjectAllocationFailure):
except to_catch:
pass
predicted = profiles[best[argmax(perf)]]
tune = not optimize.is_local_optimum(predicted, operation, x, context)
@@ -138,7 +142,7 @@ def do_tuning(device, operation, json_path):
try:
time = tools.benchmark(operation, new, _tree)
perf = performance(xx, time)
except (sc.OperationNotSupported, sc.LaunchOutOfResources, sc.MemObjectAllocationFailure):
except to_catch:
perf = 0
yy.append(0 if isinf(perf) else perf)
#Update dataset
@@ -147,7 +151,7 @@ def do_tuning(device, operation, json_path):
for ip, p in enumerate(profiles):
try:
perf = 0 if fastest and ip < nparams and predperf[ip]/fastest < .1 else performance(x,tools.benchmark(operation, p, tree))
except (sc.OperationNotSupported, sc.LaunchOutOfResources, sc.MemObjectAllocationFailure):
except to_catch:
perf = 0
y.append(0 if isinf(perf) else perf)
X.append(x)
@@ -158,6 +162,10 @@ def do_tuning(device, operation, json_path):
csv.writer(f).writerows(data)
unused = where(bincount(argmax(Y, 1))==0)[0]
profiles = [x for ix,x in enumerate(profiles) if ix not in unused]
Y = delete(Y, unused, axis=1)
#Export to JSON
json_path = tools.sanitize(device.name) + '.json' if not json_path else json_path
if os.path.isfile(json_path):