testing a simple shiftnet
This commit is contained in:
@@ -35,8 +35,11 @@ torch::Tensor shift_common(
|
||||
triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
|
||||
triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
|
||||
// Allocate output
|
||||
std::vector<int32_t> c_shapes = shift.c_shapes();
|
||||
torch::Tensor torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[2], c_shapes[3]}).cuda();
|
||||
std::vector<int32_t> _c_shapes = shift.c_shapes();
|
||||
std::vector<long int> c_shapes;
|
||||
for(auto x: _c_shapes)
|
||||
c_shapes.push_back(x);
|
||||
torch::Tensor torchc = torch::empty(c_shapes).cuda();
|
||||
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
|
||||
// Enqueue
|
||||
shift.enqueue(&stream, {&a, &b, &c});
|
||||
@@ -47,9 +50,9 @@ torch::Tensor shift_y(
|
||||
const torch::Tensor x,
|
||||
const torch::Tensor w,
|
||||
const torch::Tensor bias,
|
||||
int32_t R, int32_t S,
|
||||
int32_t stride_h, int32_t stride_w,
|
||||
int32_t* shift_h, int32_t* shift_w) {
|
||||
int64_t R, int64_t S,
|
||||
int64_t stride_h, int64_t stride_w,
|
||||
const torch::Tensor shift_h, const torch::Tensor shift_w) {
|
||||
// shapes for a
|
||||
int64_t Ca = x.size(0);
|
||||
int64_t H = x.size(1);
|
||||
@@ -61,16 +64,18 @@ torch::Tensor shift_y(
|
||||
AT_CHECK(Ca == Cb, "operands must have the same number of channels");
|
||||
int64_t C = Ca;
|
||||
// run
|
||||
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::FPROP, x, w, bias);
|
||||
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
|
||||
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
|
||||
triton::dnn::shift::FPROP, x, w, bias);
|
||||
}
|
||||
|
||||
torch::Tensor shift_dx(
|
||||
const torch::Tensor dy,
|
||||
const torch::Tensor w,
|
||||
const torch::Tensor bias,
|
||||
int32_t R, int32_t S,
|
||||
int32_t stride_h, int32_t stride_w,
|
||||
int32_t* shift_h, int32_t* shift_w) {
|
||||
int64_t R, int64_t S,
|
||||
int64_t stride_h, int64_t stride_w,
|
||||
const torch::Tensor shift_h, const torch::Tensor shift_w) {
|
||||
// shapes for a
|
||||
int64_t Ca = dy.size(0);
|
||||
int64_t H = dy.size(1);
|
||||
@@ -87,16 +92,18 @@ torch::Tensor shift_dx(
|
||||
int64_t C = Ca;
|
||||
std::swap(C, F);
|
||||
// run
|
||||
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::BPROP, dy, w, bias);
|
||||
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
|
||||
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
|
||||
triton::dnn::shift::BPROP, dy, w, bias);
|
||||
}
|
||||
|
||||
torch::Tensor shift_dw(
|
||||
const torch::Tensor dy,
|
||||
const torch::Tensor x,
|
||||
const torch::Tensor bias,
|
||||
int32_t R, int32_t S,
|
||||
int32_t stride_h, int32_t stride_w,
|
||||
int32_t* shift_h, int32_t* shift_w) {
|
||||
int64_t R, int64_t S,
|
||||
int64_t stride_h, int64_t stride_w,
|
||||
const torch::Tensor shift_h, const torch::Tensor shift_w) {
|
||||
// shapes for a
|
||||
int64_t F = dy.size(0);
|
||||
int64_t Ha = dy.size(1);
|
||||
@@ -115,7 +122,9 @@ torch::Tensor shift_dw(
|
||||
int64_t W = Wb;
|
||||
int64_t B = Bb;
|
||||
// run
|
||||
shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w, shift_h, shift_w, triton::dnn::shift::WGRAD, dy, x, bias);
|
||||
return shift_common(B, C, 1, H, W, 1, R, S, F, stride_h, stride_w,
|
||||
(int32_t*)shift_h.storage().data(), (int32_t*)shift_w.storage().data(),
|
||||
triton::dnn::shift::WGRAD, dy, x, bias);
|
||||
}
|
||||
|
||||
static auto registry =
|
||||
|
Reference in New Issue
Block a user