[examples/python/pytorch] added skeleton of wrapper for shift-conv and batch-norm
@@ -1,12 +1,9 @@
#include <vector>
#include <sstream>
#include <torch/torch.h>
#include <torch/script.h>
#include "ATen/cuda/CUDAContext.h"
#include "triton/runtime/jit.h"
#include "triton/driver/stream.h"
#include "triton/dnn/conv.h"
#include "triton/tools/bench.hpp"

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
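Aside (not part of the diff): a minimal sketch of how CHECK_* macros like these are typically used in a PyTorch C++ extension entry point. The CHECK_INPUT macro, the shift_conv_forward function, and its signature are hypothetical placeholders for the skeleton this commit starts; the commit itself only adds the includes and macros above.

// Hypothetical sketch, assuming the usual PyTorch extension validation pattern.
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor shift_conv_forward(torch::Tensor input, torch::Tensor weight) {
  // Verify both operands are CUDA tensors with contiguous memory before
  // their raw pointers would be handed to a Triton kernel.
  CHECK_INPUT(input);
  CHECK_INPUT(weight);
  // ... build a triton::dnn::conv problem descriptor and launch it here ...
  return torch::empty({0}, input.options());  // placeholder return value
}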