[examples/python/pytorch] added skeleton of wrapper for shift-conv and batch-norm

This commit is contained in:
Philippe Tillet
2019-07-09 21:54:37 -07:00
parent 63b249c1d6
commit 3b89bc8463
8 changed files with 63 additions and 151 deletions

View File

@@ -1,12 +1,9 @@
#include <vector>
#include <sstream>
#include <torch/torch.h>
#include <torch/script.h>
#include "ATen/cuda/CUDAContext.h"
#include "triton/runtime/jit.h"
#include "triton/driver/stream.h"
#include "triton/dnn/shift.h"
#include "triton/tools/bench.hpp"
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
@@ -120,3 +117,8 @@ torch::Tensor shift_dw(
// run
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::WGRAD, dy, x, bias);
}
static auto registry =
torch::jit::RegisterOperators("triton::shift_conv_y", &shift_y)
.op("triton::shift_conv_dx", &shift_dx)
.op("triton::shift_conv_dw", &shift_dw);