testing a simple shiftnet

This commit is contained in:
Philippe Tillet
2019-07-10 13:33:08 -07:00
parent 3b89bc8463
commit f665c742f9
6 changed files with 261 additions and 31 deletions

View File

@@ -4,11 +4,18 @@
#include "triton/driver/stream.h"
#include "triton/dnn/batchnorm.h"
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor>
batchnorm_ymv(const torch::Tensor fw_x,
const torch::Tensor fw_g,
const torch::Tensor fw_b,
float eps) {
double eps) {
CHECK_INPUT(fw_x);
CHECK_INPUT(fw_g);
CHECK_INPUT(fw_b);
// Wrap CUDA handles
c10::DeviceIndex device = fw_x.storage().device().index();
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
@@ -30,8 +37,9 @@ std::vector<torch::Tensor>
triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
// create template
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32", eps);
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b});
stream.synchronize();
return {fw_y, fw_m, fw_v};
}
@@ -41,7 +49,12 @@ std::vector<torch::Tensor>
const torch::Tensor fw_g,
const torch::Tensor fw_m,
const torch::Tensor fw_v,
float eps) {
double eps) {
CHECK_INPUT(fw_dy);
CHECK_INPUT(fw_x);
CHECK_INPUT(fw_g);
CHECK_INPUT(fw_m);
CHECK_INPUT(fw_v);
// Wrap CUDA handles
c10::DeviceIndex device = fw_x.storage().device().index();
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
@@ -68,4 +81,10 @@ std::vector<torch::Tensor>
// create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps);
batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
stream.synchronize();
return {fw_dx, fw_dg, fw_db};
}
static auto registry =
torch::jit::RegisterOperators("triton::batchnorm_ymv", &batchnorm_ymv)
.op("triton::batchnorm_dxdgdb", &batchnorm_dxdgdb);