[triton/dnn/conv] merged optimizations branch

- Added forward/backward support for strided convolution
- Added support for bias
- Added support for reduction splitting
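For readers new to the third bullet: reduction splitting (often called split-K) divides the inner reduction axis across GZ independent work units that each produce a partial accumulator, followed by a final pass that sums the partials; GZ is the same knob read via jit->get_int("GZ") in the diff below. A minimal CPU sketch of the idea, not the Triton kernel itself:

    #include <algorithm>
    #include <cstddef>
    #include <numeric>
    #include <vector>

    // Toy illustration of reduction splitting: a length-K dot product is
    // split into GZ chunks that could run in parallel; each chunk writes a
    // partial sum, and a second pass reduces the GZ partials.
    float split_reduction_dot(const std::vector<float>& a,
                              const std::vector<float>& b,
                              std::size_t GZ) {
        std::size_t K = a.size();
        std::vector<float> partial(GZ, 0.0f);
        std::size_t chunk = (K + GZ - 1) / GZ;          // ceil(K / GZ)
        for (std::size_t z = 0; z < GZ; ++z) {          // one work unit per z
            std::size_t lo = z * chunk;
            std::size_t hi = std::min(K, lo + chunk);
            for (std::size_t k = lo; k < hi; ++k)
                partial[z] += a[k] * b[k];
        }
        // Final reduction over the partial accumulators.
        return std::accumulate(partial.begin(), partial.end(), 0.0f);
    }

In the kernel this trades an extra reduction pass for more parallelism, which pays off when the reduction axis is large relative to the number of output tiles.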
Author: Philippe Tillet
Date:   2019-05-28 14:02:27 -04:00
parent e526ffc62b
commit a9d078c06f
47 changed files with 732 additions and 31339 deletions
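Given the new conv_common signature in the diff below (strides, padding, an optional bias tensor, and an autotune flag), a call for a 3x3, stride-2 forward convolution might look as follows. All concrete shapes, the FPROP enumerator name, and the weight layout are assumptions for illustration, not taken from this commit:

    // Hypothetical invocation of the updated entry point (values illustrative).
    // B=batch, C=in-channels, D/H/W=input dims, T/R/S=filter dims, NF=out-channels.
    torch::Tensor a    = torch::randn({32, 64, 56, 56}).cuda();  // input (layout assumed)
    torch::Tensor b    = torch::randn({64, 3, 3, 128}).cuda();   // weights (layout assumed)
    torch::Tensor bias = torch::randn({128}).cuda();             // pass an empty tensor to disable bias
    torch::Tensor out  = conv_common(
        /*B=*/32, /*C=*/64, /*D=*/1, /*H=*/56, /*W=*/56,
        /*T=*/1, /*R=*/3, /*S=*/3, /*NF=*/128,
        /*stride_d=*/1, /*stride_h=*/2, /*stride_w=*/2,
        /*pad_d=*/0, /*pad_h=*/1, /*pad_w=*/1,
        triton::dnn::conv::FPROP,                                // enumerator name assumed
        a, b, bias,
        /*autotune=*/true);

Note that has_bias is derived inside conv_common from the bias tensor's storage size, so there is no separate flag for it.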


@@ -1,10 +1,12 @@
#include <vector>
#include <sstream>
#include <torch/torch.h>
#include <torch/script.h>
#include "ATen/cuda/CUDAContext.h"
#include <vector>
#include "triton/runtime/jit.h"
#include "triton/driver/stream.h"
#include "triton/dnn/conv.h"
#include "triton/tools/bench.hpp"
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
@@ -26,10 +28,13 @@ torch::Tensor conv_common(
int32_t stride_d, int32_t stride_h, int32_t stride_w,
int32_t pad_d, int32_t pad_h, int32_t pad_w,
triton::dnn::conv::type ty,
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
bool autotune = false
) {
// Wrap CUDA handles
c10::DeviceIndex device = torcha.storage().device().index();
// Get stream
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
triton::driver::stream* stream;
@@ -37,8 +42,10 @@ torch::Tensor conv_common(
stream = m_stream.emplace(custream, new triton::driver::cu_stream(custream, false)).first->second.get();
else
stream = m_stream.at(custream).get();
// Get context
triton::driver::context* ctx = stream->context();
// Get configuration
bool has_bias = torchbias.storage().size() > 0;
conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty, has_bias};
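The hunk above shows the caching scheme: streams, configurations, and JITs are memoized in maps keyed by conv_key_t, the full problem shape plus the conv type and has_bias, so each distinct convolution is compiled only once per process. A reduced, self-contained sketch of the pattern, with the key and value types simplified:

    #include <map>
    #include <memory>
    #include <tuple>

    // Simplified memoization as in conv_common: one cached object per distinct
    // convolution shape. The real conv_key_t also carries strides, padding,
    // the conv type, and has_bias.
    using conv_key_t = std::tuple<int, int, int, int>;   // e.g. {B, C, H, W}, reduced

    struct Config { /* compiled-kernel state lives here in the real code */ };

    static std::map<conv_key_t, std::unique_ptr<Config>> m_config;

    Config* get_or_create(const conv_key_t& key) {
        auto it = m_config.find(key);
        if (it == m_config.end())
            it = m_config.emplace(key, std::make_unique<Config>()).first;
        return it->second.get();
    }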
@@ -52,22 +59,13 @@ torch::Tensor conv_common(
ty, has_bias)).first->second.get();
else
configuration = m_config.at(key).get();
// Get JIT
triton::jit* jit;
if(m_jit.find(key) == m_jit.end()){
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::string src = configuration->src();
jit->add_module("conv", src.c_str(), configuration->default_params());
triton::driver::kernel* kernel = jit->get_function("conv");
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
}
else
jit = m_jit.at(key).get();
// Get memory
// Bind memory
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
// Allocate output
std::vector<int32_t> c_shapes = configuration->c_shapes();
torch::Tensor torchc;
@@ -76,17 +74,52 @@ torch::Tensor conv_common(
else
torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
// Add module to JIT
// Get JIT
triton::jit* jit;
if(m_jit.find(key) == m_jit.end()){
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
configuration->src(oss);
std::string src = oss.str();
// benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit->get_int("GZ");
configuration->enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration->enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device());
return configuration->get_nflops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
triton::jit::tune_res_t best = jit->autotune("conv", src.c_str(), benchmark);
jit->add_module("conv", src.c_str(), best.params);
}
else {
jit->add_module("conv", src.c_str(), configuration->default_params());
}
triton::driver::kernel* kernel = jit->get_function("conv");
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
}
else
jit = m_jit.at(key).get();
// Run
triton::driver::kernel* kernel = jit->get_function("conv");
triton::jit::launch_information info = jit->get_launch_info("conv");
unsigned GZ = jit->get_int("GZ");
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
// launch info
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
configuration->set_arg(kernel, &a, &b, &c, bias);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
// enqueue
configuration->enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
return torchc;
}
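The restructured JIT block is the core of this change: jit->autotune("conv", src.c_str(), benchmark) evaluates candidate parameterizations through the benchmark lambda, which initializes the configuration, enqueues the kernel, and converts get_nflops()/ts into a throughput score; the winning best.params is then compiled for real via add_module. A library-agnostic sketch of that search loop, with hypothetical names and an exhaustive search standing in for Triton's actual tuning strategy:

    #include <functional>
    #include <limits>
    #include <vector>

    // Generic sketch of the autotune pattern above: score every candidate
    // parameter set with a caller-provided benchmark callback, keep the best.
    // The real triton::jit compiles one kernel per candidate; here a parameter
    // vector stands in for the compiled kernel.
    struct TuneResult {
        std::vector<unsigned> params;
        double perf = -std::numeric_limits<double>::infinity();
    };

    TuneResult autotune_sketch(
            const std::vector<std::vector<unsigned>>& search_space,
            const std::function<double(const std::vector<unsigned>&)>& benchmark) {
        TuneResult best;
        for (const auto& params : search_space) {
            double perf = benchmark(params);   // higher is better, cf. get_nflops()/ts
            if (perf > best.perf) {
                best.perf = perf;
                best.params = params;
            }
        }
        return best;   // caller compiles with best.params, cf. add_module("conv", ...)
    }

Because the tuned JIT is cached under the same conv_key_t, the expensive search runs at most once per shape; subsequent calls fall through to the m_jit.at(key) branch and go straight to the enqueue path.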