Code Quality: some cleanups

This commit is contained in:
Philippe Tillet
2017-05-11 16:11:40 -07:00
parent 3ff543151c
commit 9212ab5a3d
40 changed files with 112 additions and 60 deletions

View File

@@ -172,16 +172,14 @@ std::vector<param_t> Profile::predict(driver::Device const & device, std::vector
ConvProfile::ConvProfile(u_char* data): Profile(data, 8){}
templates::Conv ConvProfile::predict(driver::Device const & device, DType dtype, param_t C, param_t H, param_t W, param_t N, param_t K, param_t P, param_t Q, param_t R, param_t S,
templates::Conv ConvProfile::predict(driver::Stream& stream, driver::Device const & device, DType dtype, param_t C, param_t H, param_t W, param_t N, param_t K, param_t P, param_t Q, param_t R, param_t S,
param_t pad_h, param_t pad_w, param_t stride_h, param_t stride_w)
{
std::vector<param_t> shapes{dtype, N, K, P, Q, C, R, S};
driver::Context ctx(device);
driver::Stream stream(ctx);
driver::Buffer O(ctx, N*K*P*Q*size_of(dtype));
driver::Buffer I(ctx, C*H*W*N*size_of(dtype));
driver::Buffer F(ctx, C*K*R*S*size_of(dtype));
driver::Buffer O(stream.context(), N*K*P*Q*size_of(dtype));
driver::Buffer I(stream.context(), C*H*W*N*size_of(dtype));
driver::Buffer F(stream.context(), C*K*R*S*size_of(dtype));
scalar alpha(1., dtype);
scalar beta(0., dtype);
std::function<double(std::vector<param_t> const&)> benchmark = [&](std::vector<param_t> const& x){
@@ -201,16 +199,14 @@ templates::Conv ConvProfile::predict(driver::Device const & device, DType dtype,
GEMMProfile::GEMMProfile(u_char* data): Profile(data, 6){}
templates::GEMM GEMMProfile::predict(driver::Device const & device, DType dtype, IsaacOperation_t AT, IsaacOperation_t BT, param_t M, param_t N, param_t K,
templates::GEMM GEMMProfile::predict(driver::Stream& stream, driver::Device const & device, DType dtype, IsaacOperation_t AT, IsaacOperation_t BT, param_t M, param_t N, param_t K,
param_t offa, param_t lda, param_t offb, param_t ldb, param_t offc, param_t ldc)
{
std::vector<param_t> shapes{dtype, AT, BT, M, N, K};
driver::Context ctx(device);
driver::Stream stream(ctx);
driver::Buffer C(ctx, M*N*size_of(dtype));
driver::Buffer A(ctx, M*K*size_of(dtype));
driver::Buffer B(ctx, K*N*size_of(dtype));
driver::Buffer C(stream.context(), M*N*size_of(dtype));
driver::Buffer A(stream.context(), M*K*size_of(dtype));
driver::Buffer B(stream.context(), K*N*size_of(dtype));
scalar alpha(1., dtype);
scalar beta(0., dtype);
std::function<double(std::vector<param_t> const&)> benchmark = [&](std::vector<param_t> const& x)