testing a simple shiftnet

This commit is contained in:
Philippe Tillet
2019-07-10 13:33:08 -07:00
parent 3b89bc8463
commit f665c742f9
6 changed files with 261 additions and 31 deletions

View File

@@ -35,8 +35,11 @@ torch::Tensor shift_common(
triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
// Allocate output
std::vector<int32_t> c_shapes = shift.c_shapes();
torch::Tensor torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]}).cuda();
std::vector<int32_t> _c_shapes = shift.c_shapes();
std::vector<long int> c_shapes;
for(auto x: _c_shapes)
c_shapes.push_back(x);
torch::Tensor torchc = torch::empty(c_shapes).cuda();
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
// Enqueue
shift.enqueue(&stream, {&a, &b, &c});
@@ -47,9 +50,9 @@ torch::Tensor shift_y(
const torch::Tensor x,
const torch::Tensor w,
const torch::Tensor bias,
int32_t R, int32_t S,
int32_t stride_h, int32_t stride_w,
int32_t* shift_h, int32_t* shift_w) {
int64_t R, int64_t S,
int64_t stride_h, int64_t stride_w,
const torch::Tensor shift_h, const torch::Tensor shift_w) {
// shapes for a
int64_t Ca = x.size(0);
int64_t H = x.size(1);
@@ -61,16 +64,18 @@ torch::Tensor shift_y(
AT_CHECK(Ca == Cb, "operands must have the same number of channels");
int64_t C = Ca;
// run
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::FPROP, x, w, bias);
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
triton::dnn::shift::FPROP, x, w, bias);
}
torch::Tensor shift_dx(
const torch::Tensor dy,
const torch::Tensor w,
const torch::Tensor bias,
int32_t R, int32_t S,
int32_t stride_h, int32_t stride_w,
int32_t* shift_h, int32_t* shift_w) {
int64_t R, int64_t S,
int64_t stride_h, int64_t stride_w,
const torch::Tensor shift_h, const torch::Tensor shift_w) {
// shapes for a
int64_t Ca = dy.size(0);
int64_t H = dy.size(1);
@@ -87,16 +92,18 @@ torch::Tensor shift_dx(
int64_t C = Ca;
std::swap(C, F);
// run
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::BPROP, dy, w, bias);
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
triton::dnn::shift::BPROP, dy, w, bias);
}
torch::Tensor shift_dw(
const torch::Tensor dy,
const torch::Tensor x,
const torch::Tensor bias,
int32_t R, int32_t S,
int32_t stride_h, int32_t stride_w,
int32_t* shift_h, int32_t* shift_w) {
int64_t R, int64_t S,
int64_t stride_h, int64_t stride_w,
const torch::Tensor shift_h, const torch::Tensor shift_w) {
// shapes for a
int64_t F = dy.size(0);
int64_t Ha = dy.size(1);
@@ -115,7 +122,9 @@ torch::Tensor shift_dw(
int64_t W = Wb;
int64_t B = Bb;
// run
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::WGRAD, dy, x, bias);
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
triton::dnn::shift::WGRAD, dy, x, bias);
}
static auto registry =