fixup
This commit is contained in:
@@ -99,7 +99,7 @@ int main() {
|
|||||||
|
|
||||||
// shift
|
// shift
|
||||||
std::vector<unsigned> params = {
|
std::vector<unsigned> params = {
|
||||||
8, 2, 16, 8, 2, 32, 8, 4, 2, 2, 4, 2, 8, 4
|
8, 2, 32, 8, 2, 64, 8, 4, 2, 2, 4, 2, 8, 4
|
||||||
};
|
};
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
shift.src(oss);
|
shift.src(oss);
|
||||||
|
@@ -8,8 +8,7 @@ void shift::set_ld(const std::vector<int32_t>& shapes,
|
|||||||
std::vector<int32_t>& ld) {
|
std::vector<int32_t>& ld) {
|
||||||
size_t size = shapes.size();
|
size_t size = shapes.size();
|
||||||
ld.resize(size);
|
ld.resize(size);
|
||||||
ld[4] = 1;
|
ld[3] = 1;
|
||||||
ld[3] = shapes[4]*ld[4];
|
|
||||||
ld[2] = shapes[3]*ld[3];
|
ld[2] = shapes[3]*ld[3];
|
||||||
ld[1] = shapes[2]*ld[2];
|
ld[1] = shapes[2]*ld[2];
|
||||||
ld[0] = shapes[1]*ld[1];
|
ld[0] = shapes[1]*ld[1];
|
||||||
@@ -42,6 +41,9 @@ shift::shift(int B, int NC,
|
|||||||
shapes_c_ = {NF, H, W, B};
|
shapes_c_ = {NF, H, W, B};
|
||||||
// memory strides
|
// memory strides
|
||||||
set_ld(shapes_a_, ld_a_);
|
set_ld(shapes_a_, ld_a_);
|
||||||
|
// build LUTs
|
||||||
|
build_deltas();
|
||||||
|
build_masks();
|
||||||
}
|
}
|
||||||
|
|
||||||
void shift::build_deltas() {
|
void shift::build_deltas() {
|
||||||
@@ -89,7 +91,7 @@ std::vector<int32_t> shift::c_shapes(){
|
|||||||
}
|
}
|
||||||
|
|
||||||
size_t shift::get_nflops() {
|
size_t shift::get_nflops() {
|
||||||
return 2 * M_ * N_ * K_;
|
return 2. * M_ * N_ * K_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user