This commit is contained in:
Philippe Tillet
2019-06-27 12:39:17 -07:00
parent 9028e40f1d
commit d8526669f5
2 changed files with 6 additions and 4 deletions

View File

@@ -99,7 +99,7 @@ int main() {
// shift
std::vector<unsigned> params = {
8, 2, 16, 8, 2, 32, 8, 4, 2, 2, 4, 2, 8, 4
8, 2, 32, 8, 2, 64, 8, 4, 2, 2, 4, 2, 8, 4
};
std::ostringstream oss;
shift.src(oss);

View File

@@ -8,8 +8,7 @@ void shift::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) {
size_t size = shapes.size();
ld.resize(size);
ld[4] = 1;
ld[3] = shapes[4]*ld[4];
ld[3] = 1;
ld[2] = shapes[3]*ld[3];
ld[1] = shapes[2]*ld[2];
ld[0] = shapes[1]*ld[1];
@@ -42,6 +41,9 @@ shift::shift(int B, int NC,
shapes_c_ = {NF, H, W, B};
// memory strides
set_ld(shapes_a_, ld_a_);
// build LUTs
build_deltas();
build_masks();
}
void shift::build_deltas() {
@@ -89,7 +91,7 @@ std::vector<int32_t> shift::c_shapes(){
}
size_t shift::get_nflops() {
return 2 * M_ * N_ * K_;
return 2. * M_ * N_ * K_;
}