[dnn/conv]: now using look-up table for wgrad computation as well
This commit is contained in:
@@ -38,12 +38,6 @@ int main() {
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
// look-up table
|
||||
std::vector<int> h_delta, h_masks;
|
||||
if(ty != triton::dnn::conv::WGRAD){
|
||||
configuration.build_deltas(h_delta);
|
||||
configuration.build_masks(h_masks);
|
||||
}
|
||||
// benchmark a given convolution kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
@@ -51,12 +45,7 @@ int main() {
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
|
||||
if(ty != triton::dnn::conv::WGRAD){
|
||||
triton::driver::buffer* delta = jit.get_buffer("delta");
|
||||
triton::driver::buffer* masks = jit.get_buffer("masks");
|
||||
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
|
||||
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
|
||||
}
|
||||
configuration.init(stream, jit);
|
||||
stream->synchronize();
|
||||
configuration.set_arg(kernel, da, db, dc);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
@@ -66,7 +55,7 @@ int main() {
|
||||
return configuration.get_nflops() / ts * 1e-3;
|
||||
};
|
||||
std::string src = configuration.src();
|
||||
// jit.autotune("conv", src.c_str(), benchmark);
|
||||
jit.autotune("conv", src.c_str(), benchmark);
|
||||
jit.add_module("conv", src.c_str(), configuration.default_params());
|
||||
triton::driver::kernel* kernel = jit.get_function("conv");
|
||||
triton::jit::launch_information info = jit.get_launch_info("conv");
|
||||
@@ -74,7 +63,7 @@ int main() {
|
||||
stream->read(dc, true, 0, hc);
|
||||
configuration.cpu_ref(rc.data(), ha.data(), hb.data());
|
||||
for(size_t i = 0; i < hc.size(); i++){
|
||||
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
Reference in New Issue
Block a user