[dnn/shift] added split-k for shift-conv

This commit is contained in:
Philippe Tillet
2019-07-15 21:03:58 -07:00
parent 434f65737f
commit aa8bcf6bde
9 changed files with 166 additions and 203 deletions

View File

@@ -18,8 +18,8 @@ int main() {
// initialization
int32_t R = 3, S = 3;
int32_t B = 32, F = 128;
int32_t H = 28, W = 28;
int32_t B = 128, F = 128;
int32_t H = 16, W = 16;
int32_t C = 128;
// random shifts
@@ -44,9 +44,9 @@ int main() {
std::swap(b_size, c_size);
std::swap(a_size, b_size);
}
std::vector<NumericT> ha(B*C*H*W);
std::vector<NumericT> hb(C*F);
std::vector<float> hc(B*F*H*W);
std::vector<NumericT> ha(a_size);
std::vector<NumericT> hb(b_size);
std::vector<float> hc(c_size);
std::vector<float> rc(hc.size());
// device buffers
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);