[examples/python/pytorch] added batchnorm cpp extension
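Summary of the diff below: a PyTorch C++ wrapper for Triton's batch-normalization kernels is added (batchnorm.cpp, built into the torch_triton target), the existing conv.cpp/shift.cpp wrappers are simplified by dropping their per-key stream/JIT/config caches in favor of constructing a triton::dnn op per call, the shift wrappers gain explicit stride arguments, and the TensorFlow batchnorm ops are renamed to the fw_* input convention.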
@@ -5,6 +5,6 @@ if(${TORCH_FOUND})
     include_directories("${CUDA_HOME}/include")
     link_directories(${TORCH_LIBRARY_DIRS})
     add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
-    add_library(torch_triton SHARED conv.cpp shift.cpp)
+    add_library(torch_triton SHARED conv.cpp shift.cpp batchnorm.cpp)
     target_link_libraries(torch_triton torch triton)
 endif()
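Note on the unchanged -D_GLIBCXX_USE_CXX11_ABI=0 line above: PyTorch's prebuilt binaries of this era were compiled with the pre-C++11 libstdc++ ABI, so the extension must match it; without this definition, linking torch_triton against libtorch typically fails with undefined std::string/std::list symbols.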

examples/python/pytorch/batchnorm.cpp (new file, 72 lines)
@@ -0,0 +1,72 @@
+#include <torch/torch.h>
+#include <torch/script.h>
+#include "ATen/cuda/CUDAContext.h"
+#include "triton/driver/stream.h"
+#include "triton/dnn/batchnorm.h"
+#include "triton/tools/bench.hpp"
+
+std::vector<torch::Tensor>
+batchnorm_ymv(const torch::Tensor fw_x,
+              const torch::Tensor fw_g,
+              const torch::Tensor fw_b,
+              float eps) {
+  // Wrap CUDA handles
+  c10::DeviceIndex device = fw_x.storage().device().index();
+  CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
+  triton::driver::cu_stream stream(custream, false);
+  triton::driver::context* ctx = stream.context();
+  // get sizes
+  int C = fw_x.size(0);
+  int H = fw_x.size(1);
+  int W = fw_x.size(2);
+  int B = fw_x.size(3);
+  // allocate outputs
+  torch::Tensor fw_y = torch::empty(fw_x.sizes()).cuda();
+  torch::Tensor fw_m = torch::empty(fw_g.sizes()).cuda();
+  torch::Tensor fw_v = torch::empty(fw_g.sizes()).cuda();
+  triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.storage().data(), false);
+  triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.storage().data(), false);
+  triton::driver::cu_buffer b(ctx, (CUdeviceptr)fw_b.storage().data(), false);
+  triton::driver::cu_buffer y(ctx, (CUdeviceptr)fw_y.storage().data(), false);
+  triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
+  triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
+  // create template
+  triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32", eps);
+  batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b});
+  return {fw_y, fw_m, fw_v};
+}
+
+std::vector<torch::Tensor>
+batchnorm_dxdgdb(const torch::Tensor fw_dy,
+                 const torch::Tensor fw_x,
+                 const torch::Tensor fw_g,
+                 const torch::Tensor fw_m,
+                 const torch::Tensor fw_v,
+                 float eps) {
+  // Wrap CUDA handles
+  c10::DeviceIndex device = fw_x.storage().device().index();
+  CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
+  triton::driver::cu_stream stream(custream, false);
+  triton::driver::context* ctx = stream.context();
+  // get sizes
+  int C = fw_x.size(0);
+  int H = fw_x.size(1);
+  int W = fw_x.size(2);
+  int B = fw_x.size(3);
+  // allocate outputs
+  torch::Tensor fw_dx = torch::empty(fw_x.sizes()).cuda();
+  torch::Tensor fw_dg = torch::empty(fw_g.sizes()).cuda();
+  torch::Tensor fw_db = torch::empty(fw_g.sizes()).cuda();
+  // triton handles
+  triton::driver::cu_buffer dy(ctx, (CUdeviceptr)fw_dy.storage().data(), false);
+  triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.storage().data(), false);
+  triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.storage().data(), false);
+  triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
+  triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
+  triton::driver::cu_buffer dx(ctx, (CUdeviceptr)fw_dx.storage().data(), false);
+  triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg.storage().data(), false);
+  triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db.storage().data(), false);
+  // create config
+  triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps);
+  batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
+  return {fw_dx, fw_dg, fw_db};
+}
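The new file ends at the two entry points; this hunk does not show how they are exposed to Python, so the registration presumably lives elsewhere (the file does include torch/script.h, hinting at a TorchScript-style op registration). For orientation only, a minimal pybind11 binding sketch — the module name, docstrings, and the choice of pybind11 over torch::RegisterOperators are assumptions, not part of the commit:

    #include <torch/extension.h>

    // Hypothetical binding (NOT in this commit): exposes the Triton batchnorm
    // wrappers above to Python. batchnorm_ymv returns {y, mean, var};
    // batchnorm_dxdgdb returns {dx, dgamma, dbeta}.
    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      m.def("batchnorm_ymv",    &batchnorm_ymv,    "Triton batchnorm forward (y, mean, var)");
      m.def("batchnorm_dxdgdb", &batchnorm_dxdgdb, "Triton batchnorm backward (dx, dg, db)");
    }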

@@ -12,16 +12,6 @@
 #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
-typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
-                   int32_t, int32_t, int32_t, int32_t,
-                   int32_t, int32_t, int32_t,
-                   int32_t, int32_t, int32_t,
-                   triton::dnn::conv::type, bool> conv_key_t;
-
-static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
-static std::map<conv_key_t, std::unique_ptr<triton::jit>> m_jit;
-static std::map<conv_key_t, std::unique_ptr<triton::dnn::conv>> m_config;
-
 torch::Tensor conv_common(
     int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
     int32_t T, int32_t R, int32_t S, int32_t NF,
@@ -31,95 +21,34 @@ torch::Tensor conv_common(
|
||||
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
|
||||
bool autotune = false
|
||||
) {
|
||||
|
||||
// Wrap CUDA handles
|
||||
c10::DeviceIndex device = torcha.storage().device().index();
|
||||
|
||||
// Get stream
|
||||
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
|
||||
triton::driver::stream* stream;
|
||||
if(m_stream.find(custream) == m_stream.end())
|
||||
stream = m_stream.emplace(custream, new triton::driver::cu_stream(custream, false)).first->second.get();
|
||||
else
|
||||
stream = m_stream.at(custream).get();
|
||||
|
||||
// Get context
|
||||
triton::driver::context* ctx = stream->context();
|
||||
|
||||
// Get configuration
|
||||
triton::driver::cu_stream stream(custream, false);
|
||||
triton::driver::context* ctx = stream.context();
|
||||
// Get template
|
||||
bool has_bias = torchbias.storage().size() > 0;
|
||||
conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty, has_bias};
|
||||
triton::dnn::conv* configuration;
|
||||
if(m_config.find(key) == m_config.end())
|
||||
configuration = m_config.emplace(key, new triton::dnn::conv(
|
||||
B, C, D, H, W, T, R, S, NF,
|
||||
stride_d, stride_h, stride_w,
|
||||
pad_d, pad_h, pad_w,
|
||||
1, 1, 1,
|
||||
"fp32", "fp32", ty, has_bias)).first->second.get();
|
||||
else
|
||||
configuration = m_config.at(key).get();
|
||||
|
||||
triton::dnn::conv conv(B, C, D, H, W, T, R, S, NF,
|
||||
stride_d, stride_h, stride_w,
|
||||
pad_d, pad_h, pad_w,
|
||||
1, 1, 1,
|
||||
"fp32", "fp32", ty, has_bias);
|
||||
// Bind memory
|
||||
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
|
||||
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
|
||||
triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
|
||||
triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
|
||||
|
||||
// Allocate output
|
||||
std::vector<int32_t> c_shapes = configuration->c_shapes();
|
||||
std::vector<int32_t> c_shapes = conv.c_shapes();
|
||||
torch::Tensor torchc;
|
||||
if(ty == triton::dnn::conv::WGRAD)
|
||||
torchc = torch::empty({c_shapes[0], c_shapes[2], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
|
||||
else
|
||||
torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
|
||||
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
|
||||
|
||||
// Get JIT
|
||||
triton::jit* jit;
|
||||
if(m_jit.find(key) == m_jit.end()){
|
||||
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
configuration->src(oss);
|
||||
std::string src = oss.str();
|
||||
// benchmark a given convolution kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
unsigned GZ = jit->get_int("GZ");
|
||||
configuration->enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ configuration->enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads); },
|
||||
[&](){ stream->synchronize(); }, stream->context()->device());
|
||||
return configuration->get_nflops() / ts * 1e-3;
|
||||
};
|
||||
// auto-tune and save result
|
||||
if(autotune) {
|
||||
triton::jit::tune_res_t best = jit->autotune("conv", src.c_str(), benchmark);
|
||||
jit->add_module("conv", src.c_str(), best.params);
|
||||
}
|
||||
else {
|
||||
jit->add_module("conv", src.c_str(), configuration->default_params());
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function("conv");
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
else
|
||||
jit = m_jit.at(key).get();
|
||||
|
||||
// Run
|
||||
triton::driver::kernel* kernel = jit->get_function("conv");
|
||||
triton::jit::launch_information info = jit->get_launch_info("conv");
|
||||
unsigned GZ = jit->get_int("GZ");
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// enqueue
|
||||
configuration->enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
|
||||
// Enqueue
|
||||
conv.enqueue(&stream, {&a, &b, &c, bias});
|
||||
return torchc;
|
||||
}
|
||||
|
||||
|
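Net effect of this hunk (conv.cpp, per the CMake target above): the per-key m_stream/m_jit/m_config caches and the inline benchmark/autotune path are gone. A triton::dnn::conv descriptor is now constructed on every call, and conv.enqueue() is expected to handle source generation, compilation, caching, and launch configuration internally; as a consequence, the autotune flag that remains in the signature is no longer consulted in this function.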

@@ -12,103 +12,111 @@
 #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
-typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
-                   int32_t, int32_t, int32_t, int32_t,
-                   int32_t*, int32_t*,
-                   triton::dnn::shift::type, bool> shift_key_t;
-
-static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_shift_stream;
-static std::map<shift_key_t, std::unique_ptr<triton::jit>> m_shift_jit;
-static std::map<shift_key_t, std::unique_ptr<triton::dnn::shift>> m_shift_config;
-
 torch::Tensor shift_common(
     int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
     int32_t T, int32_t R, int32_t S, int32_t F,
-    std::vector<int32_t> shift_h, std::vector<int32_t> shift_w,
+    int32_t stride_h, int32_t stride_w,
+    int32_t* shift_h, int32_t* shift_w,
     triton::dnn::shift::type ty,
     torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
     bool autotune = false
 ) {
 
   // Wrap CUDA handles
   c10::DeviceIndex device = torcha.storage().device().index();
-
-  // Get stream
   CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
-  triton::driver::stream* stream;
-  if(m_shift_stream.find(custream) == m_shift_stream.end())
-    stream = m_shift_stream.emplace(custream, new triton::driver::cu_stream(custream, false)).first->second.get();
-  else
-    stream = m_shift_stream.at(custream).get();
-
-  // Get context
-  triton::driver::context* ctx = stream->context();
-
+  triton::driver::cu_stream stream(custream, false);
+  triton::driver::context* ctx = stream.context();
   // Get configuration
   bool has_bias = torchbias.storage().size() > 0;
-  shift_key_t key = {B, C, D, H, W, T, R, S, F, shift_h.data(), shift_w.data(), ty, has_bias};
-  triton::dnn::shift* configuration;
-  if(m_shift_config.find(key) == m_shift_config.end())
-    configuration = m_shift_config.emplace(key, new triton::dnn::shift(
-                                             B, C, D, H, W, T, R, S, F,
-                                             shift_h, shift_w, "fp32", "fp32",
-                                             ty, has_bias)).first->second.get();
-  else
-    configuration = m_shift_config.at(key).get();
-
+  triton::dnn::shift shift(B, C, D, H, W, T, R, S, F,
+                           stride_h, stride_w,
+                           shift_h, shift_w, "fp32", "fp32",
+                           ty, has_bias);
   // Bind memory
   triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
   triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
   triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
   triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
 
   // Allocate output
-  std::vector<int32_t> c_shapes = configuration->c_shapes();
+  std::vector<int32_t> c_shapes = shift.c_shapes();
   torch::Tensor torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]}).cuda();
   triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
 
-  // Get JIT
-  triton::jit* jit;
-  if(m_shift_jit.find(key) == m_shift_jit.end()){
-    jit = m_shift_jit.emplace(key, new triton::jit(ctx)).first->second.get();
-    std::ostringstream oss;
-    configuration->triton_c_src(oss);
-    std::string src = oss.str();
-    // benchmark a given shift kernel
-    auto benchmark = [&](triton::driver::kernel* kernel,
-                         triton::jit::launch_information info) {
-      configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
-      unsigned TM = info.global_range_size[0];
-      unsigned TN = info.global_range_size[1];
-      unsigned nthreads = info.num_threads;
-      configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
-      stream->synchronize();
-      double ts = triton::tools::bench([&](){ configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads); },
-                                       [&](){ stream->synchronize(); }, stream->context()->device());
-      return configuration->num_flops() / ts * 1e-3;
-    };
-    // auto-tune and save result
-    if(autotune) {
-      triton::jit::tune_res_t best = jit->autotune("shift", src.c_str(), benchmark);
-      jit->add_module("shift", src.c_str(), best.params);
-    }
-    else {
-      jit->add_module("shift", src.c_str(), jit->get_valid("shift", src.c_str()));
-    }
-    triton::driver::kernel* kernel = jit->get_function("shift");
-    configuration->init_impl(stream, (triton::driver::cu_module*)kernel->module());
-  }
-  else
-    jit = m_shift_jit.at(key).get();
-
-  // Run
-  triton::driver::kernel* kernel = jit->get_function("shift");
-  triton::jit::launch_information info = jit->get_launch_info("shift");
-  // launch info
-  unsigned TM = info.global_range_size[0];
-  unsigned TN = info.global_range_size[1];
-  unsigned nthreads = info.num_threads;
-  // enqueue
-  configuration->enqueue_impl(stream, kernel, &a, &b, &c, TM, TN, nthreads);
+  // Enqueue
+  shift.enqueue(&stream, {&a, &b, &c});
   return torchc;
 }
 
 torch::Tensor shift_y(
     const torch::Tensor x,
     const torch::Tensor w,
     const torch::Tensor bias,
     int32_t R, int32_t S,
+    int32_t stride_h, int32_t stride_w,
     int32_t* shift_h, int32_t* shift_w) {
   // shapes for a
   int64_t Ca = x.size(0);
   int64_t H = x.size(1);
   int64_t W = x.size(2);
   int64_t B = x.size(3);
   // shapes for b
   int64_t Cb = w.size(0);
   int64_t F = w.size(1);
   AT_CHECK(Ca == Cb, "operands must have the same number of channels");
   int64_t C = Ca;
   // run
+  return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::FPROP, x, w, bias);
 }
 
 torch::Tensor shift_dx(
     const torch::Tensor dy,
     const torch::Tensor w,
     const torch::Tensor bias,
     int32_t R, int32_t S,
+    int32_t stride_h, int32_t stride_w,
     int32_t* shift_h, int32_t* shift_w) {
   // shapes for a
   int64_t Ca = dy.size(0);
   int64_t H = dy.size(1);
   int64_t W = dy.size(2);
   int64_t B = dy.size(3);
+  H *= stride_h;
+  W *= stride_w;
   // shapes for b
   int64_t Cb = w.size(0);
   int64_t F = w.size(1);
   std::swap(Cb, F);
   // checks
   AT_CHECK(Ca == Cb, "operands must have the same number of channels");
   int64_t C = Ca;
   std::swap(C, F);
   // run
+  return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::BPROP, dy, w, bias);
 }
 
 torch::Tensor shift_dw(
     const torch::Tensor dy,
     const torch::Tensor x,
     const torch::Tensor bias,
     int32_t R, int32_t S,
+    int32_t stride_h, int32_t stride_w,
     int32_t* shift_h, int32_t* shift_w) {
   // shapes for a
   int64_t F = dy.size(0);
   int64_t Ha = dy.size(1);
   int64_t Wa = dy.size(2);
   int64_t Ba = dy.size(3);
   // shapes for b
   int64_t C = x.size(0);
   int64_t Hb = x.size(1);
   int64_t Wb = x.size(2);
   int64_t Bb = x.size(3);
   // check
+  AT_CHECK(Ha*stride_h == Hb, "operands must have the same image height");
+  AT_CHECK(Wa*stride_w == Wb, "operands must have the same image width");
   AT_CHECK(Ba == Bb, "operands must have the same batch size");
   int64_t H = Hb;
   int64_t W = Wb;
   int64_t B = Bb;
   // run
+  return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::WGRAD, dy, x, bias);
 }
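The shift_common signature change ripples into the wrappers: per-channel shift tables are now raw int32_t* pointers and strides are explicit. A hedged call-site sketch for the new shift_y — the shapes, shift values, and surrounding setup are illustrative assumptions, not code from this commit:

    // Hypothetical caller (not part of this commit). Layout is CHWB, as read
    // by shift_common above; each shift table holds one offset per channel.
    int64_t C = 64, H = 32, W = 32, B = 16, F = 128;
    torch::Tensor x = torch::randn({C, H, W, B}).cuda();
    torch::Tensor w = torch::randn({C, F}).cuda();
    torch::Tensor bias = torch::empty({0});              // empty storage => has_bias == false
    std::vector<int32_t> sh(C, 0), sw(C, 0);             // zero shifts for illustration
    torch::Tensor y = shift_y(x, w, bias, /*R=*/3, /*S=*/3,
                              /*stride_h=*/1, /*stride_w=*/1,
                              sh.data(), sw.data());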
@@ -35,31 +35,31 @@ public:
   triton::driver::context* ctx = sstream.context();
   triton::driver::stream* stream = &sstream;
   // get inputs
-  const Tensor& x = context->input(0);
-  const Tensor& g = context->input(1);
-  const Tensor& b = context->input(2);
+  const Tensor& fw_x = context->input(0);
+  const Tensor& fw_g = context->input(1);
+  const Tensor& fw_b = context->input(2);
   // get sizes
-  int C = x.dim_size(0);
-  int H = x.dim_size(1);
-  int W = x.dim_size(2);
-  int B = x.dim_size(3);
+  int C = fw_x.dim_size(0);
+  int H = fw_x.dim_size(1);
+  int W = fw_x.dim_size(2);
+  int B = fw_x.dim_size(3);
   // allocate outputs
-  Tensor* y = nullptr;
-  Tensor* m = nullptr;
-  Tensor* v = nullptr;
-  OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &y));
-  OP_REQUIRES_OK(context, context->allocate_output(1, g.shape(), &m));
-  OP_REQUIRES_OK(context, context->allocate_output(2, g.shape(), &v));
+  Tensor* fw_y = nullptr;
+  Tensor* fw_m = nullptr;
+  Tensor* fw_v = nullptr;
+  OP_REQUIRES_OK(context, context->allocate_output(0, fw_x.shape(), &fw_y));
+  OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_m));
+  OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_v));
   // triton handles
-  triton::driver::cu_buffer tx(ctx, (CUdeviceptr)x.flat<float>().data(), false);
-  triton::driver::cu_buffer tg(ctx, (CUdeviceptr)g.flat<float>().data(), false);
-  triton::driver::cu_buffer tb(ctx, (CUdeviceptr)b.flat<float>().data(), false);
-  triton::driver::cu_buffer ty(ctx, (CUdeviceptr)y->flat<float>().data(), false);
-  triton::driver::cu_buffer tm(ctx, (CUdeviceptr)m->flat<float>().data(), false);
-  triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v->flat<float>().data(), false);
+  triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
+  triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
+  triton::driver::cu_buffer b(ctx, (CUdeviceptr)fw_b.flat<float>().data(), false);
+  triton::driver::cu_buffer y(ctx, (CUdeviceptr)fw_y->flat<float>().data(), false);
+  triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m->flat<float>().data(), false);
+  triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v->flat<float>().data(), false);
   // create config
   triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
-  batchnorm.enqueue(stream, {&ty, &tm, &tv, &tx, &tg, &tb});
+  batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
 }
 
 private:
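This hunk and the next apply the same mechanical rename in the TensorFlow batchnorm ops (file paths are not shown in this view): the op inputs and outputs take fw_* names so that the short names (x, g, b, y, m, v, and dy, dx, dg, db below) can go to the triton::driver::cu_buffer handles, mirroring the naming scheme of the new PyTorch batchnorm.cpp above.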
@@ -99,35 +99,35 @@ public:
   triton::driver::context* ctx = sstream.context();
   triton::driver::stream* stream = &sstream;
   // get inputs
-  const Tensor& dy = context->input(0);
-  const Tensor& x = context->input(1);
-  const Tensor& g = context->input(2);
-  const Tensor& m = context->input(3);
-  const Tensor& v = context->input(4);
+  const Tensor& fw_dy = context->input(0);
+  const Tensor& fw_x = context->input(1);
+  const Tensor& fw_g = context->input(2);
+  const Tensor& fw_m = context->input(3);
+  const Tensor& fw_v = context->input(4);
   // get sizes
-  int C = x.dim_size(0);
-  int H = x.dim_size(1);
-  int W = x.dim_size(2);
-  int B = x.dim_size(3);
+  int C = fw_x.dim_size(0);
+  int H = fw_x.dim_size(1);
+  int W = fw_x.dim_size(2);
+  int B = fw_x.dim_size(3);
   // allocate outputs
-  Tensor* dx = nullptr;
-  Tensor* dg = nullptr;
-  Tensor* db = nullptr;
-  OP_REQUIRES_OK(context, context->allocate_output(0, x.shape(), &dx));
-  OP_REQUIRES_OK(context, context->allocate_output(1, g.shape(), &dg));
-  OP_REQUIRES_OK(context, context->allocate_output(2, g.shape(), &db));
+  Tensor* fw_dx = nullptr;
+  Tensor* fw_dg = nullptr;
+  Tensor* fw_db = nullptr;
+  OP_REQUIRES_OK(context, context->allocate_output(0, fw_x.shape(), &fw_dx));
+  OP_REQUIRES_OK(context, context->allocate_output(1, fw_g.shape(), &fw_dg));
+  OP_REQUIRES_OK(context, context->allocate_output(2, fw_g.shape(), &fw_db));
   // triton handles
-  triton::driver::cu_buffer tdy(ctx, (CUdeviceptr)dy.flat<float>().data(), false);
-  triton::driver::cu_buffer tx(ctx, (CUdeviceptr)x.flat<float>().data(), false);
-  triton::driver::cu_buffer tg(ctx, (CUdeviceptr)g.flat<float>().data(), false);
-  triton::driver::cu_buffer tm(ctx, (CUdeviceptr)m.flat<float>().data(), false);
-  triton::driver::cu_buffer tv(ctx, (CUdeviceptr)v.flat<float>().data(), false);
-  triton::driver::cu_buffer tdx(ctx, (CUdeviceptr)dx->flat<float>().data(), false);
-  triton::driver::cu_buffer tdg(ctx, (CUdeviceptr)dg->flat<float>().data(), false);
-  triton::driver::cu_buffer tdb(ctx, (CUdeviceptr)db->flat<float>().data(), false);
+  triton::driver::cu_buffer dy(ctx, (CUdeviceptr)fw_dy.flat<float>().data(), false);
+  triton::driver::cu_buffer x(ctx, (CUdeviceptr)fw_x.flat<float>().data(), false);
+  triton::driver::cu_buffer g(ctx, (CUdeviceptr)fw_g.flat<float>().data(), false);
+  triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.flat<float>().data(), false);
+  triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.flat<float>().data(), false);
+  triton::driver::cu_buffer dx(ctx, (CUdeviceptr)fw_dx->flat<float>().data(), false);
+  triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg->flat<float>().data(), false);
+  triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db->flat<float>().data(), false);
   // create config
   triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
-  batchnorm.enqueue(stream, {&tdx, &tdg, &tdb, &tdy, &tx, &tg, &tm, &tv});
+  batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
 }
 
 private:
@@ -73,7 +73,6 @@ public:
     if(OP == triton::dnn::shift::BPROP)
       std::swap(C, F);
   }
-
 }
 
 void Compute(OpKernelContext* context){