[examples/python/pytorch] added skeleton of wrapper for shift-conv and batch-norm
This commit is contained in:
@@ -1,12 +1,9 @@
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <torch/torch.h>
|
||||
#include <torch/script.h>
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/dnn/shift.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
// Argument-validation helpers for the extension's torch entry points.
// CHECK_CUDA(x): fail with a descriptive error unless tensor `x` lives on a CUDA device.
// CHECK_CONTIGUOUS(x): fail unless tensor `x` has contiguous memory layout.
// NOTE: TORCH_CHECK replaces the deprecated AT_CHECK macro, and
// Tensor::is_cuda() replaces the deprecated type().is_cuda() accessor;
// the error messages are unchanged for callers/tests that match on them.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
@@ -120,3 +117,8 @@ torch::Tensor shift_dw(
|
||||
// run
|
||||
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::WGRAD, dy, x, bias);
|
||||
}
|
||||
|
||||
// Register the shift-conv kernels as TorchScript custom operators in the
// "triton" namespace, so they are reachable from Python as
// torch.ops.triton.shift_conv_y / shift_conv_dx / shift_conv_dw.
// Registration happens as a side effect of this static initializer at
// library load time.
// NOTE(review): shift_y / shift_dx / shift_dw are presumably the forward,
// input-gradient, and weight-gradient entry points defined earlier in this
// file (shift_dw is partially visible above) — confirm against full source.
static auto registry =
  torch::jit::RegisterOperators("triton::shift_conv_y", &shift_y)
    .op("triton::shift_conv_dx", &shift_dx)
    .op("triton::shift_conv_dw", &shift_dw);
|
||||
|
Reference in New Issue
Block a user