Code Quality: some cleanups

This commit is contained in:
Philippe Tillet
2017-05-11 16:11:40 -07:00
parent 3ff543151c
commit 9212ab5a3d
40 changed files with 112 additions and 60 deletions

View File

@@ -141,8 +141,8 @@ int main(int argc, char* argv[])
drv::Buffer F(ctx, K*C*R*S*dtsize);
std::vector<double> times;
times.push_back(bench([&](){ sc::CONV(device, stream, dtype, N, K, P, Q, C, R, S, H, W, pad_h, pad_w, stride_h, stride_w, alpha, I, F, beta, O); }, [&](){ stream.synchronize(); }, device));
times.push_back(bench([&](){ sc::driver::cudnnConv(dtype, ctx, stream, H, W, N, K, P, Q, C, R, S, pad_h, pad_w, stride_h, stride_w, alpha, I, F, beta, O); }, [&](){ stream.synchronize(); }, device));
// times.push_back(bench([&](){ sc::CONV(device, stream, dtype, N, K, P, Q, C, R, S, H, W, pad_h, pad_w, stride_h, stride_w, alpha, I, F, beta, O); }, [&](){ stream.synchronize(); }, device));
times.push_back(bench([&](){ sc::driver::cudnnConv(dtype, stream, H, W, N, K, P, Q, C, R, S, pad_h, pad_w, stride_h, stride_w, alpha, I, F, beta, O); }, [&](){ stream.synchronize(); }, device));
speedup.push_back(times[1]/times[0]);
print_results(times, {str(N), str(K), str(P), str(Q), str(C), str(R), str(S)}, [&](double tsec){ return sc::templates::Conv::tflops(P,Q,K,N,C,R,S,tsec);});
}
@@ -205,7 +205,7 @@ int main(int argc, char* argv[])
std::vector<double> times;
times.push_back(bench([&](){ sc::GEMM(device, stream, dtype, AT, BT, M, N, K, 0, lda, 0, ldb, 0, ldc, alpha, A, B, beta, C); }, [&](){ stream.synchronize(); }, device));
times.push_back(bench([&](){ sc::driver::cublasGemm(dtype, ctx, stream, cuAT, cuBT, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); }, [&](){ stream.synchronize(); }, device));
times.push_back(bench([&](){ sc::driver::cublasGemm(dtype, stream, cuAT, cuBT, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); }, [&](){ stream.synchronize(); }, device));
speedup.push_back(times[1]/times[0]);
print_results(times, {str(AT), str(BT), str(M), str(N), str(K)}, [&](double tsec){ return sc::templates::GEMM::tflops(M, N, K, tsec);});
}

View File

@@ -45,7 +45,7 @@ void GEMM(driver::Device const & device, driver::Stream & stream,
static std::function<value_type()> compile = [&](){
//Fetch profile
runtime::GEMMProfile* profile = (runtime::GEMMProfile*)runtime::database.at({device.architecture(), runtime::GEMM}).get();
templates::GEMM generator = profile->predict(device, dtype, AT, BT, M, N, K, offa, lda, offb, ldb, offc, ldc);
templates::GEMM generator = profile->predict(stream, device, dtype, AT, BT, M, N, K, offa, lda, offb, ldb, offc, ldc);
//Execute
std::string src = generator.dump(device, "gemm");
driver::Module module(stream.context(), src);
@@ -69,14 +69,13 @@ void CONV(driver::Device const & device, driver::Stream & stream,
static std::function<value_type()> compile = [&](){
//Fetch profile
runtime::ConvProfile* profile = (runtime::ConvProfile*)runtime::database.at({device.architecture(), runtime::CONV}).get();
templates::Conv generator = profile->predict(device, dtype, C, H, W, N, K, P, Q, R, S, pad_h, pad_w, stride_h, stride_w);
templates::Conv generator = profile->predict(stream, device, dtype, C, H, W, N, K, P, Q, R, S, pad_h, pad_w, stride_h, stride_w);
//Execute
std::string src = generator.dump(device, "fconv");
std::string src = generator.dump(device, "conv");
driver::Module module(stream.context(), src);
return value_type(std::make_shared<templates::Conv>(generator), std::make_shared<driver::Kernel>(module, "fconv"));
return value_type(std::make_shared<templates::Conv>(generator), std::make_shared<driver::Kernel>(module, "conv"));
};
static cpp::CachedMap<key_type, value_type> cache(compile);
//Retrieve profile/kernel and execute
value_type const & value = cache.get(key_type(stream, dtype, N, K, P, Q, C, R, S, pad_h, pad_w, stride_h, stride_w));
value.first->enqueue(*value.second, stream, alpha, I, F, beta, O);

View File

@@ -24,6 +24,7 @@
#define ISAAC_DRIVER_BUFFER_H
#include "isaac/driver/handle.h"
#include "isaac/driver/context.h"
namespace isaac
{
@@ -41,6 +42,7 @@ public:
Handle<CUdeviceptr> const & cu() const;
private:
Context context_;
Handle<CUdeviceptr> cu_;
size_t size_;
};

View File

@@ -52,6 +52,14 @@ private:
std::string cache_path_;
};
// Helper tied to a Context: declares a constructor taking the context and a
// destructor, and holds a reference to that context for the object's lifetime.
// NOTE(review): ctx_ is a reference member, so the Context passed to the
// constructor must outlive this object.
class ContextSwitcher{
public:
ContextSwitcher(Context const & ctx);
~ContextSwitcher();
private:
// Context supplied at construction (held by reference, not copied).
Context const & ctx_;
};
}
}

View File

@@ -40,19 +40,20 @@ template<typename... Args> void cublasGemm_impl(double, Args... args){ driver::d
// Typed cuBLAS GEMM launcher. Fetches the cuBLAS handle for a context, binds
// that handle to the given stream, and forwards the operands to
// cublasGemm_impl with alpha/beta and the device pointers cast to cuType.
// This hunk shows two signatures back to back: one taking (ctx, queue) and one
// taking only `stream`, whose body reads the context from stream.context().
template<class cuType>
inline void cublasGemm_dispatch(Context const & ctx, Stream& queue, char AT, char BT, int32_t M, int32_t N, int32_t K, void* alpha, Buffer const & A, int32_t lda, Buffer const & B, int32_t ldb, void* beta, Buffer& C, int32_t ldc){
inline void cublasGemm_dispatch(Stream& stream, char AT, char BT, int32_t M, int32_t N, int32_t K, void* alpha, Buffer const & A, int32_t lda, Buffer const & B, int32_t ldb, void* beta, Buffer& C, int32_t ldc){
// 'N' selects CUBLAS_OP_N; every other character selects CUBLAS_OP_T.
auto cu_trans = [](char xt) { return (xt=='N')?CUBLAS_OP_N:CUBLAS_OP_T; };
cublasHandle_t handle = dispatch::cublasHandle(ctx);
dispatch::cublasSetStream_v2(handle, (CUstream)queue);
cublasHandle_t handle = dispatch::cublasHandle(stream.context());
dispatch::cublasSetStream_v2(handle, (CUstream)stream);
// Buffers are converted to raw CUdeviceptr values before being cast to cuType*.
CUdeviceptr cuA = A, cuB = B, cuC = C;
cublasGemm_impl(cuType(), handle, cu_trans(AT), cu_trans(BT), M, N, K, (cuType*)alpha, (const cuType*)cuA, lda, (const cuType*)cuB, ldb, (cuType*)beta, (cuType*)cuC, ldc);
}
// Runtime-dtype entry point for cuBLAS GEMM: picks the half/float/double
// instantiation of cublasGemm_dispatch from `dtype` and passes alpha/beta as
// raw pointers obtained from scalar::data().
// This hunk shows two signatures back to back: one taking (ctx, queue) and one
// taking only `stream`; the ContextSwitcher line and the `stream` cases belong
// to the latter.
inline void cublasGemm(DType dtype, Context const & ctx, Stream& queue, char AT, char BT, int32_t M, int32_t N, int32_t K, scalar alpha, Buffer const & A, int32_t lda, Buffer const & B, int32_t ldb, scalar beta, Buffer& C, int32_t ldc){
inline void cublasGemm(DType dtype, Stream& stream, char AT, char BT, int32_t M, int32_t N, int32_t K, scalar alpha, Buffer const & A, int32_t lda, Buffer const & B, int32_t ldb, scalar beta, Buffer& C, int32_t ldc){
// Scoped object constructed from the stream's context; destroyed on return.
ContextSwitcher ctx_switch(stream.context());
switch(dtype){
case HALF_TYPE: return cublasGemm_dispatch<half>(ctx, queue, AT, BT, M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc);
case FLOAT_TYPE: return cublasGemm_dispatch<float>(ctx, queue, AT, BT, M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc);
case DOUBLE_TYPE: return cublasGemm_dispatch<double>(ctx, queue, AT, BT, M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc);
case HALF_TYPE: return cublasGemm_dispatch<half>(stream, AT, BT, M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc);
case FLOAT_TYPE: return cublasGemm_dispatch<float>(stream, AT, BT, M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc);
case DOUBLE_TYPE: return cublasGemm_dispatch<double>(stream, AT, BT, M, N, K, alpha.data(), A, lda, B, ldb, beta.data(), C, ldc);
// NOTE(review): a bare `throw;` with no exception being handled calls
// std::terminate; unsupported dtypes therefore abort rather than raise a
// catchable error — confirm this is intended.
default: throw;
}
}
@@ -66,12 +67,19 @@ inline cudnnDataType_t cudnnDtype(DType dtype){
throw;
}
inline void cudnnConv(DType dtype, Context const & ctx, Stream& queue, int32_t H, int32_t W, int32_t N, int32_t K, int32_t P, int32_t Q, int32_t C, int32_t R, int32_t S,
inline void cudnnConv(DType dtype, Stream& stream, int32_t H, int32_t W, int32_t N, int32_t K, int32_t P, int32_t Q, int32_t C, int32_t R, int32_t S,
int32_t pad_h, int32_t pad_w, int32_t stride_h, int32_t stride_w, scalar alpha, Buffer const & I, Buffer const & F, scalar beta, Buffer const & O){
driver::Context const & ctx = stream.context();
// ContextSwitcher switch_ctx(ctx);
// CUcontext cuctx;
dispatch::cuCtxSetCurrent(ctx);
// std::cout << cuctx << " " << CUcontext(ctx) << std::endl;
cudnnHandle_t handle = dispatch::cudnnHandle(ctx);
cudnnDataType_t cutype = cudnnDtype(dtype);
dispatch::cudnnSetStream(handle, (CUstream)queue);
dispatch::cudnnSetStream(handle, (CUstream)stream);
cudnnTensorDescriptor_t tO, tI;
cudnnFilterDescriptor_t tF;
cudnnConvolutionDescriptor_t conv;

View File

@@ -27,11 +27,11 @@
#include <dlfcn.h>
//CUDA Backend
#include "isaac/driver/external/CUDA/cuda.h"
#include "isaac/driver/external/CUDA/nvrtc.h"
#include "isaac/driver/external/CUDA/cublas.h"
#include "isaac/driver/external/CUDA/cudnn.h"
#include "isaac/driver/external/CUDA/nvml.h"
#include "isaac/external/CUDA/cuda.h"
#include "isaac/external/CUDA/nvrtc.h"
#include "isaac/external/CUDA/cublas.h"
#include "isaac/external/CUDA/cudnn.h"
#include "isaac/external/CUDA/nvml.h"
//Exceptions
#include <iostream>
@@ -86,6 +86,8 @@ public:
//CUDA
static CUresult cuCtxGetCurrent(CUcontext *pctx);
static CUresult cuCtxSetCurrent(CUcontext ctx);
static CUresult cuCtxDestroy_v2(CUcontext ctx);
static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
static CUresult cuDeviceGet(CUdevice *device, int ordinal);
@@ -164,6 +166,7 @@ private:
//CUDA
static void* cuCtxGetCurrent_;
static void* cuCtxSetCurrent_;
static void* cuCtxDestroy_v2_;
static void* cuEventCreate_;
static void* cuDeviceGet_;

View File

@@ -105,14 +105,14 @@ private:
// Profile subclass for convolutions. Built from a raw byte pointer; predict()
// returns a templates::Conv for the given data type and problem dimensions.
class ConvProfile: public Profile{
public:
ConvProfile(u_char* data);
// Two first lines of the predict() declaration appear back to back in this
// hunk: one without and one with a leading driver::Stream& parameter. The
// trailing pad/stride parameter line is shared by both.
templates::Conv predict(driver::Device const & device, DType dtype, param_t C, param_t H, param_t W, param_t N, param_t K, param_t P, param_t Q, param_t R, param_t S,
templates::Conv predict(driver::Stream& stream, driver::Device const & device, DType dtype, param_t C, param_t H, param_t W, param_t N, param_t K, param_t P, param_t Q, param_t R, param_t S,
param_t pad_h, param_t pad_w, param_t stride_h, param_t stride_w);
};
// Profile subclass for GEMM. Built from a raw byte pointer; predict() returns
// a templates::GEMM for the given data type, transposition flags, sizes,
// offsets and leading dimensions.
class GEMMProfile: public Profile{
public:
GEMMProfile(u_char* data);
// Two first lines of the predict() declaration appear back to back in this
// hunk: one without and one with a leading driver::Stream& parameter. The
// trailing offset/leading-dimension parameter line is shared by both.
templates::GEMM predict(driver::Device const & device, DType dtype, IsaacOperation_t AT, IsaacOperation_t BT, param_t M, param_t N, param_t K,
templates::GEMM predict(driver::Stream& stream, driver::Device const & device, DType dtype, IsaacOperation_t AT, IsaacOperation_t BT, param_t M, param_t N, param_t K,
param_t offa, param_t lda, param_t offb, param_t ldb, param_t offc, param_t ldc);
};

View File

@@ -34,11 +34,17 @@ namespace driver
{
Buffer::Buffer(Context const & /*context*/, size_t size) : size_(size)
{ dispatch::cuMemAlloc(&*cu_, size); }
// Stores the context and byte size, then allocates `size` bytes with
// dispatch::cuMemAlloc into the wrapped device pointer.
Buffer::Buffer(Context const & context, size_t size) : context_(context), size_(size)
{
// Scoped ContextSwitcher built from context_; it is destroyed when the
// constructor body exits, i.e. after the allocation call.
ContextSwitcher ctx_switch(context_);
dispatch::cuMemAlloc(&*cu_, size);
}
// Issues dispatch::cuMemsetD8Async on `queue` to write the byte value 0 over
// size_ bytes of the buffer. Two bodies appear back to back in this hunk: a
// one-line body, and a braced body that first constructs a ContextSwitcher.
void Buffer::set_zero(Stream const & queue)
{ dispatch::cuMemsetD8Async(*cu_, 0, size_, queue); }
{
// Scoped ContextSwitcher built from the buffer's stored context_.
ContextSwitcher ctx_switch(context_);
dispatch::cuMemsetD8Async(*cu_, 0, size_, queue);
}
// Read-only accessor returning a const reference to the wrapped CUdeviceptr handle.
Handle<CUdeviceptr> const & Buffer::cu() const
{ return cu_; }

View File

@@ -21,6 +21,7 @@
*/
#include <iostream>
#include <cassert>
#include "isaac/driver/context.h"
#include "isaac/driver/module.h"
@@ -63,7 +64,10 @@ Context::Context(CUcontext context, bool take_ownership): cu_(context, take_owne
{ }
// Builds a Context for `device`: records the device and the result of
// get_cache_path(), then calls dispatch::cuCtxCreate with CU_CTX_SCHED_AUTO.
// Two bodies appear back to back in this hunk: a one-line body, and a braced
// body that additionally pops the context right after creating it.
Context::Context(Device const & device): device_(device), cache_path_(get_cache_path())
{ dispatch::cuCtxCreate(&*cu_, CU_CTX_SCHED_AUTO, (CUdevice)device); }
{
dispatch::cuCtxCreate(&*cu_, CU_CTX_SCHED_AUTO, (CUdevice)device);
// The popped handle is not retrieved (NULL out-pointer).
// NOTE(review): presumably this keeps the freshly created context from
// staying current so that users go through ContextSwitcher — confirm.
dispatch::cuCtxPopCurrent_v2(NULL);
}
Device const & Context::device() const
{ return device_; }
@@ -74,5 +78,18 @@ std::string const & Context::cache_path() const
Handle<CUcontext> const & Context::cu() const
{ return cu_; }
/* Context Switcher */
// Binds a reference to `ctx` and calls dispatch::cuCtxPushCurrent_v2 on it.
// The matching pop happens in the destructor.
ContextSwitcher::ContextSwitcher(Context const & ctx): ctx_(ctx)
{ dispatch::cuCtxPushCurrent_v2(ctx_); }
// Calls dispatch::cuCtxPopCurrent_v2 and, when assertions are enabled, checks
// that the popped context is the one this object was constructed with.
ContextSwitcher::~ContextSwitcher()
{
CUcontext tmp;
dispatch::cuCtxPopCurrent_v2(&tmp);
// Debug-only check: assert is compiled out when NDEBUG is defined.
// NOTE(review): if dispatch calls throw on failure, an exception here would
// escape a destructor — confirm how dispatch reports errors.
assert(tmp==(CUcontext)ctx_ && "Switching back to invalid context!");
}
}
}

View File

@@ -176,6 +176,7 @@ CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*)
CUDA_DEFINE1(CUresult, cuCtxGetCurrent, CUcontext*)
CUDA_DEFINE1(CUresult, cuCtxSetCurrent, CUcontext)
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t, CUstream)
CUDA_DEFINE1(CUresult, cuCtxPushCurrent_v2, CUcontext)
CUDA_DEFINE1(CUresult, cuCtxPopCurrent_v2, CUcontext*)
@@ -260,6 +261,7 @@ void* dispatch::cudnn_;
//CUDA
void* dispatch::cuCtxGetCurrent_;
void* dispatch::cuCtxSetCurrent_;
void* dispatch::cuCtxDestroy_v2_;
void* dispatch::cuEventCreate_;
void* dispatch::cuDeviceGet_;

View File

@@ -48,16 +48,19 @@ CUjit_target_enum cutarget(Device::Architecture arch){
}
Module::Module(Context const & context, std::string const & source, bool is_ir) : context_(context), source_(source){
ContextSwitcher ctx_switch(context_);
//PTX passed directly
if(is_ir){
CUjit_option opt[] = {CU_JIT_TARGET, CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
unsigned int errbufsize = 8096;
std::string errbuf(errbufsize, 0);
CUjit_target_enum target = cutarget(context_.device().architecture());
void* optval[] = {reinterpret_cast<void*>(target), reinterpret_cast<void*>(errbufsize), (void*)errbuf.data()};
//CUjit_target_enum target = cutarget(context.device().architecture());
void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)errbuf.data()};
try{
dispatch::cuModuleLoadDataEx(&*cu_, source.data(), 3, opt, optval);
dispatch::cuModuleLoadDataEx(&*cu_, source.data(), 2, opt, optval);
}catch(exception::cuda::base const &){
std::cerr << "Compilation Failed! Log: " << std::endl;
std::cerr << errbuf << std::endl;
throw;
}

View File

@@ -42,15 +42,22 @@ Stream::Stream(CUstream stream, bool take_ownership): cu_(stream, take_ownership
{}
// Builds a Stream for `context`: stores the context, initialises the handle
// from a value-initialised CUstream with the flag `true`, then calls
// dispatch::cuStreamCreate with flags 0. Two bodies appear back to back in
// this hunk: a one-line body, and a braced body that first constructs a
// ContextSwitcher.
Stream::Stream(Context const & context): context_(context), cu_(CUstream(), true)
{ dispatch::cuStreamCreate(&*cu_, 0); }
{
// Scoped ContextSwitcher built from the stream's stored context_.
ContextSwitcher ctx_switch(context_);
dispatch::cuStreamCreate(&*cu_, 0);
}
// Calls dispatch::cuStreamSynchronize on the wrapped stream handle. Two bodies
// appear back to back in this hunk: a one-line body, and a braced body that
// first constructs a ContextSwitcher.
void Stream::synchronize()
{ dispatch::cuStreamSynchronize(*cu_); }
{
// Scoped ContextSwitcher built from the stream's stored context_.
ContextSwitcher ctx_switch(context_);
dispatch::cuStreamSynchronize(*cu_);
}
// Read-only accessor returning a const reference to the context stored in this stream.
Context const & Stream::context() const
{ return context_; }
void Stream::enqueue(Kernel const & kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event* event){
ContextSwitcher ctx_switch(context_);
if(event)
dispatch::cuEventRecord(((cu_event_t)*event).first, *cu_);
dispatch::cuLaunchKernel(kernel, grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_,(void**)kernel.cu_params(), NULL);

View File

@@ -172,16 +172,14 @@ std::vector<param_t> Profile::predict(driver::Device const & device, std::vector
ConvProfile::ConvProfile(u_char* data): Profile(data, 8){}
templates::Conv ConvProfile::predict(driver::Device const & device, DType dtype, param_t C, param_t H, param_t W, param_t N, param_t K, param_t P, param_t Q, param_t R, param_t S,
templates::Conv ConvProfile::predict(driver::Stream& stream, driver::Device const & device, DType dtype, param_t C, param_t H, param_t W, param_t N, param_t K, param_t P, param_t Q, param_t R, param_t S,
param_t pad_h, param_t pad_w, param_t stride_h, param_t stride_w)
{
std::vector<param_t> shapes{dtype, N, K, P, Q, C, R, S};
driver::Context ctx(device);
driver::Stream stream(ctx);
driver::Buffer O(ctx, N*K*P*Q*size_of(dtype));
driver::Buffer I(ctx, C*H*W*N*size_of(dtype));
driver::Buffer F(ctx, C*K*R*S*size_of(dtype));
driver::Buffer O(stream.context(), N*K*P*Q*size_of(dtype));
driver::Buffer I(stream.context(), C*H*W*N*size_of(dtype));
driver::Buffer F(stream.context(), C*K*R*S*size_of(dtype));
scalar alpha(1., dtype);
scalar beta(0., dtype);
std::function<double(std::vector<param_t> const&)> benchmark = [&](std::vector<param_t> const& x){
@@ -201,16 +199,14 @@ templates::Conv ConvProfile::predict(driver::Device const & device, DType dtype,
GEMMProfile::GEMMProfile(u_char* data): Profile(data, 6){}
templates::GEMM GEMMProfile::predict(driver::Device const & device, DType dtype, IsaacOperation_t AT, IsaacOperation_t BT, param_t M, param_t N, param_t K,
templates::GEMM GEMMProfile::predict(driver::Stream& stream, driver::Device const & device, DType dtype, IsaacOperation_t AT, IsaacOperation_t BT, param_t M, param_t N, param_t K,
param_t offa, param_t lda, param_t offb, param_t ldb, param_t offc, param_t ldc)
{
std::vector<param_t> shapes{dtype, AT, BT, M, N, K};
driver::Context ctx(device);
driver::Stream stream(ctx);
driver::Buffer C(ctx, M*N*size_of(dtype));
driver::Buffer A(ctx, M*K*size_of(dtype));
driver::Buffer B(ctx, K*N*size_of(dtype));
driver::Buffer C(stream.context(), M*N*size_of(dtype));
driver::Buffer A(stream.context(), M*K*size_of(dtype));
driver::Buffer B(stream.context(), K*N*size_of(dtype));
scalar alpha(1., dtype);
scalar beta(0., dtype);
std::function<double(std::vector<param_t> const&)> benchmark = [&](std::vector<param_t> const& x)

View File

@@ -841,6 +841,7 @@ std::string Conv::dump(drv::Device const & device, std::string const & name){
inc_k += step_k;
}
iss << "}" << std::endl;
// std::cout << iss.str() << std::endl;
return iss.str();
}

View File

@@ -88,7 +88,7 @@ void do_test_impl(sc::driver::Context const & ctx, size_t N, size_t K, size_t H,
stream.write(O, true, 0, iO.size()*dtsize, iO.data());
stream.write(I, true, 0, iI.size()*dtsize, iI_cudnn.data());
stream.write(F, true, 0, iF.size()*dtsize, iF_cudnn.data());
sc::driver::cudnnConv(dtype, ctx, stream, H, W, N, K, P, Q, C, R, S, pad_h, pad_w, stride_h, stride_w, alpha, I, F, beta, O);
sc::driver::cudnnConv(dtype, stream, H, W, N, K, P, Q, C, R, S, pad_h, pad_w, stride_h, stride_w, alpha, I, F, beta, O);
std::vector<DTYPE> rO_cudnn(iO.size());
std::vector<DTYPE> rO(iO.size());
stream.read(O, true, 0, rO_cudnn.size()*dtsize, (void*)rO_cudnn.data());

View File

@@ -52,25 +52,25 @@ void do_test(sc::driver::Context const & ctx, sc::IsaacOperation_t AT, sc::Isaac
for(size_t i = 0; i < iA.size(); ++i) iA[i] = (float)rand()/RAND_MAX;
for(size_t i = 0; i < iB.size(); ++i) iB[i] = (float)rand()/RAND_MAX;
drv::Stream queue(ctx);
queue.write(C, true, 0, M*N*dtsize, iC.data());
queue.write(A, true, 0, M*K*dtsize, iA.data());
queue.write(B, true, 0, K*N*dtsize, iB.data());
drv::Stream stream(ctx);
stream.write(C, true, 0, M*N*dtsize, iC.data());
stream.write(A, true, 0, M*K*dtsize, iA.data());
stream.write(B, true, 0, K*N*dtsize, iB.data());
//Ground result (cuBLAS)
char cuAT = (AT==sc::ISAAC_OP_T)?'T':'N';
char cuBT = (BT==sc::ISAAC_OP_T)?'T':'N';
sc::driver::cublasGemm(dtype, ctx, queue, cuAT, cuBT, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
sc::driver::cublasGemm(dtype, stream, cuAT, cuBT, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
std::vector<DTYPE> rC(M*N);
queue.read(C, true, 0, M*N*dtsize, (void*)rC.data());
queue.write(C, true, 0, M*N*dtsize, iC.data());
stream.read(C, true, 0, M*N*dtsize, (void*)rC.data());
stream.write(C, true, 0, M*N*dtsize, iC.data());
//ISAAC result
std::vector<DTYPE> hC(M*N);
//Test selected profile
sc::GEMM(ctx.device(), queue, dtype, AT, BT, M, N, K, offa, lda, offb, ldb, offc, ldc, alpha, A, B, beta, C);
queue.read(C, true, 0, M*N*dtsize, (void*)hC.data());
sc::GEMM(ctx.device(), stream, dtype, AT, BT, M, N, K, offa, lda, offb, ldb, offc, ldc, alpha, A, B, beta, C);
stream.read(C, true, 0, M*N*dtsize, (void*)hC.data());
if(!is_correct(hC, rC, max_rounding_error(DTYPE(K))))
exit(EXIT_FAILURE);
@@ -93,11 +93,11 @@ void do_test(sc::driver::Context const & ctx, sc::IsaacOperation_t AT, sc::Isaac
drv::Kernel kernel(program, "gemm");
//Launch
gemm.enqueue(kernel, queue, alpha, A, B, beta, C);
queue.synchronize();
gemm.enqueue(kernel, stream, alpha, A, B, beta, C);
stream.synchronize();
//Test
queue.read(C, true, 0, M*N*dtsize, (void*)hC.data());
stream.read(C, true, 0, M*N*dtsize, (void*)hC.data());
size_t depth = x[11]*x[12]*x[13];
double eps = max_rounding_error(DTYPE(K/depth))*depth;
if(!is_correct(hC, rC, eps))