[general] a bunch of fixes in anticipation of proper triton vs cudnn
benchmarks * DNN: Added partial auto-tuning mode and skeleton for heuristics * Examples: Moduralized benchmarking and now evaluating ResNet-18 shapes
This commit is contained in:
@@ -11,14 +11,14 @@
|
||||
|
||||
void extract_shapes(const torch::Tensor &x,
|
||||
int64_t &C, int64_t &H, int64_t &W, int64_t &B,
|
||||
triton::dnn::shift::layout_t layout) {
|
||||
if(layout == triton::dnn::shift::CHWN){
|
||||
triton::dnn::layout_t layout) {
|
||||
if(layout == triton::dnn::CHWN){
|
||||
C = x.size(0);
|
||||
H = x.size(1);
|
||||
W = x.size(2);
|
||||
B = x.size(3);
|
||||
}
|
||||
else if(layout == triton::dnn::shift::NCHW){
|
||||
else if(layout == triton::dnn::NCHW){
|
||||
B = x.size(0);
|
||||
C = x.size(1);
|
||||
H = x.size(2);
|
||||
@@ -29,14 +29,14 @@ void extract_shapes(const torch::Tensor &x,
|
||||
}
|
||||
}
|
||||
|
||||
static const triton::dnn::shift::layout_t layout = triton::dnn::shift::NCHW;
|
||||
static const triton::dnn::layout_t layout = triton::dnn::NCHW;
|
||||
|
||||
torch::Tensor shift_common(
|
||||
int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
|
||||
int32_t T, int32_t R, int32_t S, int32_t F,
|
||||
int32_t stride_h, int32_t stride_w,
|
||||
int32_t* shift_h, int32_t* shift_w,
|
||||
triton::dnn::shift::op_t ty, triton::dnn::shift::layout_t layout,
|
||||
triton::dnn::op_t op, triton::dnn::layout_t layout,
|
||||
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
|
||||
bool autotune = false
|
||||
) {
|
||||
@@ -59,7 +59,7 @@ torch::Tensor shift_common(
|
||||
triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
|
||||
stride_h, stride_w,
|
||||
shift_h, shift_w, dtype, dtype,
|
||||
ty, has_bias, layout);
|
||||
op, has_bias, layout);
|
||||
// Bind memory
|
||||
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
|
||||
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
|
||||
@@ -74,8 +74,9 @@ torch::Tensor shift_common(
|
||||
|
||||
|
||||
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
|
||||
std::cout << B << ", " << C << ", " << H << ", " << W << ", " << T << ", " << R << ", " << S << ", " << F << ", " << stride_h << ", " << stride_w << ", " << op << ", " << layout << std::endl;
|
||||
// Enqueue
|
||||
shift.enqueue(&stream, {&a, &b, &c}, true);
|
||||
shift.enqueue(&stream, {&a, &b, &c}, triton::dnn::NO_TUNING);
|
||||
return torchc;
|
||||
}
|
||||
|
||||
@@ -99,7 +100,7 @@ torch::Tensor shift_y(
|
||||
// run
|
||||
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
|
||||
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
|
||||
triton::dnn::shift::FPROP, layout, x, w, bias);
|
||||
triton::dnn::FPROP, layout, x, w, bias);
|
||||
}
|
||||
|
||||
torch::Tensor shift_dx(
|
||||
@@ -127,7 +128,7 @@ torch::Tensor shift_dx(
|
||||
// run
|
||||
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
|
||||
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
|
||||
triton::dnn::shift::BPROP, layout, dy, w, bias);
|
||||
triton::dnn::BPROP, layout, dy, w, bias);
|
||||
}
|
||||
|
||||
torch::Tensor shift_dw(
|
||||
@@ -155,7 +156,7 @@ torch::Tensor shift_dw(
|
||||
// run
|
||||
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
|
||||
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
|
||||
triton::dnn::shift::WGRAD, layout, dy, x, bias);
|
||||
triton::dnn::WGRAD, layout, dy, x, bias);
|
||||
}
|
||||
|
||||
static auto registry =
|
||||
|
Reference in New Issue
Block a user