[codegen] shift: added sketch for shift-convolution backpropagation
This commit is contained in:
@@ -19,8 +19,8 @@ int main() {
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t BS = 32, F = 1024;
|
||||
int32_t H = 32, W = 32;
|
||||
int32_t BS = 4, F = 1024;
|
||||
int32_t H = 16, W = 16;
|
||||
int32_t C = 1024;
|
||||
|
||||
// random shifts
|
||||
@@ -31,7 +31,7 @@ int main() {
|
||||
shift_w[c] = rand() % S - S/2;
|
||||
}
|
||||
// configuration
|
||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::FPROP);
|
||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
|
||||
// host buffers
|
||||
std::vector<float> hc(shift.c_size());
|
||||
std::vector<float> rc(shift.c_size());
|
||||
|
Reference in New Issue
Block a user