[codegen] shift: added sketch for shift-convolution backpropagation

This commit is contained in:
Philippe Tillet
2019-07-02 16:39:07 -07:00
parent 6cfb575d29
commit 8fc253946c
16 changed files with 231 additions and 107 deletions

View File

@@ -19,8 +19,8 @@ int main() {
// initialization
int32_t R = 3, S = 3;
int32_t BS = 32, F = 1024;
int32_t H = 32, W = 32;
int32_t BS = 4, F = 1024;
int32_t H = 16, W = 16;
int32_t C = 1024;
// random shifts
@@ -31,7 +31,7 @@ int main() {
shift_w[c] = rand() % S - S/2;
}
// configuration
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::FPROP);
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
// host buffers
std::vector<float> hc(shift.c_size());
std::vector<float> rc(shift.c_size());