[dnn/conv]: now using look-up table for wgrad computation as well

This commit is contained in:
Philippe Tillet
2019-05-15 14:57:31 -04:00
parent 15a967c81e
commit ece7beea3c
3 changed files with 201 additions and 156 deletions

View File

@@ -38,12 +38,6 @@ int main() {
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
// look-up table
std::vector<int> h_delta, h_masks;
if(ty != triton::dnn::conv::WGRAD){
configuration.build_deltas(h_delta);
configuration.build_masks(h_masks);
}
// benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
@@ -51,12 +45,7 @@ int main() {
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
if(ty != triton::dnn::conv::WGRAD){
triton::driver::buffer* delta = jit.get_buffer("delta");
triton::driver::buffer* masks = jit.get_buffer("masks");
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
}
configuration.init(stream, jit);
stream->synchronize();
configuration.set_arg(kernel, da, db, dc);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
@@ -66,7 +55,7 @@ int main() {
return configuration.get_nflops() / ts * 1e-3;
};
std::string src = configuration.src();
// jit.autotune("conv", src.c_str(), benchmark);
jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
@@ -74,7 +63,7 @@ int main() {
stream->read(dc, true, 0, hc);
configuration.cpu_ref(rc.data(), ha.data(), hb.data());
for(size_t i = 0; i < hc.size(); i++){
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}