This commit is contained in:
Philippe Tillet
2019-06-27 12:39:17 -07:00
parent 9028e40f1d
commit d8526669f5
2 changed files with 6 additions and 4 deletions

View File

@@ -99,7 +99,7 @@ int main() {
// shift // shift
std::vector<unsigned> params = { std::vector<unsigned> params = {
8, 2, 16, 8, 2, 32, 8, 4, 2, 2, 4, 2, 8, 4 8, 2, 32, 8, 2, 64, 8, 4, 2, 2, 4, 2, 8, 4
}; };
std::ostringstream oss; std::ostringstream oss;
shift.src(oss); shift.src(oss);

View File

@@ -8,8 +8,7 @@ void shift::set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld) { std::vector<int32_t>& ld) {
size_t size = shapes.size(); size_t size = shapes.size();
ld.resize(size); ld.resize(size);
ld[4] = 1; ld[3] = 1;
ld[3] = shapes[4]*ld[4];
ld[2] = shapes[3]*ld[3]; ld[2] = shapes[3]*ld[3];
ld[1] = shapes[2]*ld[2]; ld[1] = shapes[2]*ld[2];
ld[0] = shapes[1]*ld[1]; ld[0] = shapes[1]*ld[1];
@@ -42,6 +41,9 @@ shift::shift(int B, int NC,
shapes_c_ = {NF, H, W, B}; shapes_c_ = {NF, H, W, B};
// memory strides // memory strides
set_ld(shapes_a_, ld_a_); set_ld(shapes_a_, ld_a_);
// build LUTs
build_deltas();
build_masks();
} }
void shift::build_deltas() { void shift::build_deltas() {
@@ -89,7 +91,7 @@ std::vector<int32_t> shift::c_shapes(){
} }
size_t shift::get_nflops() { size_t shift::get_nflops() {
return 2 * M_ * N_ * K_; return 2. * M_ * N_ * K_;
} }