This commit is contained in:
Philippe Tillet
2019-06-28 11:13:36 -07:00
parent f4dedb522c
commit 21fd0fd65e
3 changed files with 9 additions and 5 deletions

View File

@@ -68,12 +68,12 @@ int main() {
// shift // shift
std::vector<unsigned> params = { std::vector<unsigned> params = {
8, 2, 32, 8, 2, 64, 8, 4, 2, 2, 4, 2, 8, 4 4, 2, 32, 8, 2, 32, 8, 4, 2, 2, 8, 8, 4
}; };
std::ostringstream oss; std::ostringstream oss;
shift.src(oss); shift.src(oss);
std::string src = oss.str(); std::string src = oss.str();
jit.autotune("shift", src.c_str(), benchmark); // jit.autotune("shift", src.c_str(), benchmark);
jit.add_module("shift", src.c_str(), params); jit.add_module("shift", src.c_str(), params);
triton::driver::kernel* kernel = jit.get_function("shift"); triton::driver::kernel* kernel = jit.get_function("shift");
triton::jit::launch_information info = jit.get_launch_info("shift"); triton::jit::launch_information info = jit.get_launch_info("shift");
@@ -81,7 +81,7 @@ int main() {
stream->read(dc, true, 0, hc); stream->read(dc, true, 0, hc);
shift.cpu_ref(rc.data(), ha.data(), hb.data()); shift.cpu_ref(rc.data(), ha.data(), hb.data());
for(size_t i = 0; i < hc.size(); i++) for(size_t i = 0; i < hc.size(); i++)
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }

View File

@@ -92,10 +92,11 @@ public:
for(int32_t c = 0; c < NC_; ++c){ for(int32_t c = 0; c < NC_; ++c){
int32_t h = p; int32_t h = p;
int32_t w = q; int32_t w = q;
if(h >= BH_/2 && h < AH_ - BH_/2) if(h >= BH_/2 && h < AH_ - BH_/2
&& w >= BW_/2 && w < AW_ - BW_/2){
h += shift_h_[c]; h += shift_h_[c];
if(w > BW_/2 && w < AW_ - BW_/2)
w += shift_w_[c]; w += shift_w_[c];
}
IN_DTYPE a = I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_]; IN_DTYPE a = I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_];
IN_DTYPE b = F[k + c*NF_]; IN_DTYPE b = F[k + c*NF_];
acc = std::fma(a, b, acc); acc = std::fma(a, b, acc);

View File

@@ -53,6 +53,9 @@ void shift::build_deltas() {
h_deltas_[c] += shift_h_[c]*ld_a_[1]; h_deltas_[c] += shift_h_[c]*ld_a_[1];
h_deltas_[c] += shift_w_[c]*ld_a_[2]; h_deltas_[c] += shift_w_[c]*ld_a_[2];
} }
for(unsigned c = 0; c < NC_; c++){
h_deltas_[c + 256] = c*ld_a_[0];
}
} }
void shift::build_masks() { void shift::build_masks() {