[general] a bunch of fixes in anticipation of proper triton vs cudnn benchmarks

* DNN: Added a partial auto-tuning mode and a skeleton for heuristics
* Examples: Modularized benchmarking; now evaluating ResNet-18 shapes (see the shape sketch below)
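
The benchmark changes themselves live in the example files, which this excerpt does not show. As a rough illustration only, a ResNet-18-style shape grid for a shift-conv benchmark might look like the following; the batch size and the exact shape list are assumptions, not taken from the commit:

    // Hypothetical ResNet-18-style benchmark grid; the real shape list lives
    // in the modified examples, which are not shown in this excerpt.
    struct conv_shape { int B, C, H, W, F; };
    static const conv_shape resnet18_shapes[] = {
      {16,  64, 56, 56,  64},   // conv2_x blocks
      {16, 128, 28, 28, 128},   // conv3_x blocks
      {16, 256, 14, 14, 256},   // conv4_x blocks
      {16, 512,  7,  7, 512},   // conv5_x blocks
    };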
Author: Philippe Tillet
Date:   2019-07-21 20:17:56 -07:00
Parent: b1d81a5802
Commit: ead368d1ed
10 changed files with 221 additions and 147 deletions


@@ -11,14 +11,14 @@
 void extract_shapes(const torch::Tensor &x,
                     int64_t &C, int64_t &H, int64_t &W, int64_t &B,
-                    triton::dnn::shift::layout_t layout) {
-  if(layout == triton::dnn::shift::CHWN){
+                    triton::dnn::layout_t layout) {
+  if(layout == triton::dnn::CHWN){
     C = x.size(0);
     H = x.size(1);
     W = x.size(2);
     B = x.size(3);
   }
-  else if(layout == triton::dnn::shift::NCHW){
+  else if(layout == triton::dnn::NCHW){
     B = x.size(0);
     C = x.size(1);
     H = x.size(2);
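
Reconstructed in full, the helper after this change reads roughly as below. The `W = x.size(3);` line for NCHW and the fallback branch fall outside the shown diff context, so they are assumptions:

    // Post-change helper, reconstructed from the hunk above. The NCHW W
    // extent and the error branch are assumptions; Triton's dnn header
    // (path not shown in this excerpt) is also required.
    #include <stdexcept>
    #include <torch/extension.h>

    void extract_shapes(const torch::Tensor &x,
                        int64_t &C, int64_t &H, int64_t &W, int64_t &B,
                        triton::dnn::layout_t layout) {
      if(layout == triton::dnn::CHWN){
        C = x.size(0);
        H = x.size(1);
        W = x.size(2);
        B = x.size(3);
      }
      else if(layout == triton::dnn::NCHW){
        B = x.size(0);
        C = x.size(1);
        H = x.size(2);
        W = x.size(3);
      }
      else
        throw std::runtime_error("extract_shapes: unsupported layout");
    }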
@@ -29,14 +29,14 @@ void extract_shapes(const torch::Tensor &x,
   }
 }
-static const triton::dnn::shift::layout_t layout = triton::dnn::shift::NCHW;
+static const triton::dnn::layout_t layout = triton::dnn::NCHW;
 torch::Tensor shift_common(
     int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
     int32_t T, int32_t R, int32_t S, int32_t F,
     int32_t stride_h, int32_t stride_w,
     int32_t* shift_h, int32_t* shift_w,
-    triton::dnn::shift::op_t ty, triton::dnn::shift::layout_t layout,
+    triton::dnn::op_t op, triton::dnn::layout_t layout,
     torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
     bool autotune = false
 ) {
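
The enumerators used at the call sites (`FPROP`, `BPROP`, `WGRAD`, `CHWN`, `NCHW`) now live at namespace scope rather than inside the `shift` class. Their real declarations are in Triton's headers, not in this diff; a minimal sketch consistent with the usage in this file would be:

    // Minimal sketch of the flattened enums, inferred from the call sites in
    // this file -- not copied from the actual Triton headers.
    namespace triton { namespace dnn {

    enum op_t     { FPROP, BPROP, WGRAD };  // forward / data-grad / weight-grad
    enum layout_t { CHWN, NCHW };           // activation memory layouts

    } }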
@@ -59,7 +59,7 @@ torch::Tensor shift_common(
   triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
                            stride_h, stride_w,
                            shift_h, shift_w, dtype, dtype,
-                           ty, has_bias, layout);
+                           op, has_bias, layout);
   // Bind memory
   triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
   triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
@@ -74,8 +74,9 @@ torch::Tensor shift_common(
   triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
+  std::cout << B << ", " << C << ", " << H << ", " << W << ", " << T << ", " << R << ", " << S << ", " << F << ", " << stride_h << ", " << stride_w << ", " << op << ", " << layout << std::endl;
   // Enqueue
-  shift.enqueue(&stream, {&a, &b, &c}, true);
+  shift.enqueue(&stream, {&a, &b, &c}, triton::dnn::NO_TUNING);
   return torchc;
 }
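
`enqueue` now takes a named tuning policy instead of a boolean. Only `NO_TUNING` appears in this diff; given the commit message's "partial auto-tuning mode", the enum plausibly looks something like the sketch below, where the other two enumerators are guesses:

    // Hypothetical tuning-policy enum: NO_TUNING is the only value visible in
    // this diff; PARTIAL_TUNING and FULL_TUNING are inferred from the commit
    // message and may not match the real names.
    namespace triton { namespace dnn {
    enum tune_t {
      NO_TUNING,       // reuse cached or default launch parameters
      PARTIAL_TUNING,  // search a pruned configuration space (faster to tune)
      FULL_TUNING      // exhaustive search over all tunable parameters
    };
    } }

    // A benchmark run would then opt into tuning per call, e.g.:
    //   shift.enqueue(&stream, {&a, &b, &c}, triton::dnn::PARTIAL_TUNING);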
@@ -99,7 +100,7 @@
   // run
   return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
                       (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
-                      triton::dnn::shift::FPROP, layout, x, w, bias);
+                      triton::dnn::FPROP, layout, x, w, bias);
 }

 torch::Tensor shift_dx(
@@ -127,7 +128,7 @@ torch::Tensor shift_dx(
   // run
   return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
                       (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
-                      triton::dnn::shift::BPROP, layout, dy, w, bias);
+                      triton::dnn::BPROP, layout, dy, w, bias);
 }

 torch::Tensor shift_dw(
@@ -155,7 +156,7 @@ torch::Tensor shift_dw(
   // run
   return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
                       (int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
-                      triton::dnn::shift::WGRAD, layout, dy, x, bias);
+                      triton::dnn::WGRAD, layout, dy, x, bias);
 }

 static auto registry =
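
The registration statement is truncated in this view. In 2019-era PyTorch C++ extensions, such a `registry` typically used `torch::RegisterOperators` to expose the functions as custom ops; a hypothetical completion, with illustrative operator names not taken from the real file, might be:

    // Hypothetical completion of the truncated registration above -- the
    // operator names are made up for illustration.
    #include <torch/script.h>

    static auto registry =
        torch::RegisterOperators()
            .op("triton::shift_conv_y",  &shift_y)
            .op("triton::shift_conv_dx", &shift_dx)
            .op("triton::shift_conv_dw", &shift_dw);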