fixup
This commit is contained in:
@@ -99,7 +99,7 @@ int main() {
|
||||
|
||||
// shift
|
||||
std::vector<unsigned> params = {
|
||||
8, 2, 16, 8, 2, 32, 8, 4, 2, 2, 4, 2, 8, 4
|
||||
8, 2, 32, 8, 2, 64, 8, 4, 2, 2, 4, 2, 8, 4
|
||||
};
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
|
@@ -8,8 +8,7 @@ void shift::set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld) {
|
||||
size_t size = shapes.size();
|
||||
ld.resize(size);
|
||||
ld[4] = 1;
|
||||
ld[3] = shapes[4]*ld[4];
|
||||
ld[3] = 1;
|
||||
ld[2] = shapes[3]*ld[3];
|
||||
ld[1] = shapes[2]*ld[2];
|
||||
ld[0] = shapes[1]*ld[1];
|
||||
@@ -42,6 +41,9 @@ shift::shift(int B, int NC,
|
||||
shapes_c_ = {NF, H, W, B};
|
||||
// memory strides
|
||||
set_ld(shapes_a_, ld_a_);
|
||||
// build LUTs
|
||||
build_deltas();
|
||||
build_masks();
|
||||
}
|
||||
|
||||
void shift::build_deltas() {
|
||||
@@ -89,7 +91,7 @@ std::vector<int32_t> shift::c_shapes(){
|
||||
}
|
||||
|
||||
size_t shift::get_nflops() {
|
||||
return 2 * M_ * N_ * K_;
|
||||
return 2. * M_ * N_ * K_;
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user