testing a simple shiftnet
This commit is contained in:
@@ -4,11 +4,18 @@
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/dnn/batchnorm.h"
|
||||
|
||||
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::vector<torch::Tensor>
|
||||
batchnorm_ymv(const torch::Tensor fw_x,
|
||||
const torch::Tensor fw_g,
|
||||
const torch::Tensor fw_b,
|
||||
float eps) {
|
||||
double eps) {
|
||||
CHECK_INPUT(fw_x);
|
||||
CHECK_INPUT(fw_g);
|
||||
CHECK_INPUT(fw_b);
|
||||
// Wrap CUDA handles
|
||||
c10::DeviceIndex device = fw_x.storage().device().index();
|
||||
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
|
||||
@@ -30,8 +37,9 @@ std::vector<torch::Tensor>
|
||||
triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
|
||||
triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
|
||||
// create template
|
||||
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32", eps);
|
||||
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
|
||||
batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b});
|
||||
stream.synchronize();
|
||||
return {fw_y, fw_m, fw_v};
|
||||
}
|
||||
|
||||
@@ -41,7 +49,12 @@ std::vector<torch::Tensor>
|
||||
const torch::Tensor fw_g,
|
||||
const torch::Tensor fw_m,
|
||||
const torch::Tensor fw_v,
|
||||
float eps) {
|
||||
double eps) {
|
||||
CHECK_INPUT(fw_dy);
|
||||
CHECK_INPUT(fw_x);
|
||||
CHECK_INPUT(fw_g);
|
||||
CHECK_INPUT(fw_m);
|
||||
CHECK_INPUT(fw_v);
|
||||
// Wrap CUDA handles
|
||||
c10::DeviceIndex device = fw_x.storage().device().index();
|
||||
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
|
||||
@@ -68,4 +81,10 @@ std::vector<torch::Tensor>
|
||||
// create config
|
||||
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps);
|
||||
batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
|
||||
stream.synchronize();
|
||||
return {fw_dx, fw_dg, fw_db};
|
||||
}
|
||||
|
||||
static auto registry =
|
||||
torch::jit::RegisterOperators("triton::batchnorm_ymv", &batchnorm_ymv)
|
||||
.op("triton::batchnorm_dxdgdb", &batchnorm_dxdgdb);
|
||||
|
Reference in New Issue
Block a user