[dnn/shift] added split-k for shift-conv
This commit is contained in:
@@ -18,8 +18,8 @@ int main() {
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t B = 32, F = 128;
|
||||
int32_t H = 28, W = 28;
|
||||
int32_t B = 128, F = 128;
|
||||
int32_t H = 16, W = 16;
|
||||
int32_t C = 128;
|
||||
|
||||
// random shifts
|
||||
@@ -44,9 +44,9 @@ int main() {
|
||||
std::swap(b_size, c_size);
|
||||
std::swap(a_size, b_size);
|
||||
}
|
||||
std::vector<NumericT> ha(B*C*H*W);
|
||||
std::vector<NumericT> hb(C*F);
|
||||
std::vector<float> hc(B*F*H*W);
|
||||
std::vector<NumericT> ha(a_size);
|
||||
std::vector<NumericT> hb(b_size);
|
||||
std::vector<float> hc(c_size);
|
||||
std::vector<float> rc(hc.size());
|
||||
// device buffers
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
|
||||
|
Reference in New Issue
Block a user