[triton/dnn/conv] merged optimizations branch
- Added forward/backward support for strided convolution
- Added support for bias
- Added support for reduction splitting
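For context, a minimal usage sketch of the updated interface, based only on the call sites visible in the test-script changes below (the ConvFunction.apply(input, weight, bias, stride, padding) signature and the weight layout used there); the shapes are illustrative, not the only supported configuration:

import torch
import triton

# shapes mirror the updated test script: NCHW input, (C, R, S, K) weight layout
x = torch.randn(64, 3, 8, 8).cuda()
w = torch.randn(3, 3, 3, 6).cuda()
bias = torch.randn(6).cuda()

# Triton convolution with stride (1, 1) and padding (1, 1), as in the test script
y = triton.ConvFunction.apply(x, w, bias, (1, 1), (1, 1))

# reference result from conv2d: permute the weight to the (K, C, R, S) layout it expects
y_ref = torch.nn.functional.conv2d(x, w.permute(3, 0, 1, 2), bias=bias, stride=1, padding=1)
print((y - y_ref).norm(2))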
@@ -1,10 +1,12 @@
#include <vector>
#include <sstream>
#include <torch/torch.h>
#include <torch/script.h>
#include "ATen/cuda/CUDAContext.h"
#include <vector>
#include "triton/runtime/jit.h"
#include "triton/driver/stream.h"
#include "triton/dnn/conv.h"
#include "triton/tools/bench.hpp"

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
@@ -26,10 +28,13 @@ torch::Tensor conv_common(
int32_t stride_d, int32_t stride_h, int32_t stride_w,
int32_t pad_d, int32_t pad_h, int32_t pad_w,
triton::dnn::conv::type ty,
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias,
bool autotune = false
) {

// Wrap CUDA handles
c10::DeviceIndex device = torcha.storage().device().index();

// Get stream
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
triton::driver::stream* stream;
@@ -37,8 +42,10 @@ torch::Tensor conv_common(
stream = m_stream.emplace(custream, new triton::driver::cu_stream(custream, false)).first->second.get();
else
stream = m_stream.at(custream).get();

// Get context
triton::driver::context* ctx = stream->context();

// Get configuration
bool has_bias = torchbias.storage().size() > 0;
conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty, has_bias};
@@ -52,22 +59,13 @@ torch::Tensor conv_common(
ty, has_bias)).first->second.get();
else
configuration = m_config.at(key).get();
// Get JIT
triton::jit* jit;
if(m_jit.find(key) == m_jit.end()){
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::string src = configuration->src();
jit->add_module("conv", src.c_str(), configuration->default_params());
triton::driver::kernel* kernel = jit->get_function("conv");
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
}
else
jit = m_jit.at(key).get();
// Get memory

// Bind memory
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
triton::driver::buffer* bias = has_bias ? &cubias : nullptr;

// Allocate output
std::vector<int32_t> c_shapes = configuration->c_shapes();
torch::Tensor torchc;
@@ -76,17 +74,52 @@ torch::Tensor conv_common(
else
torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
// Add module to JIT

// Get JIT
triton::jit* jit;
if(m_jit.find(key) == m_jit.end()){
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
std::ostringstream oss;
configuration->src(oss);
std::string src = oss.str();
// benchmark a given convolution kernel
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
unsigned GZ = jit->get_int("GZ");
configuration->enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
stream->synchronize();
double ts = triton::tools::bench([&](){ configuration->enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads); },
[&](){ stream->synchronize(); }, stream->context()->device());
return configuration->get_nflops() / ts * 1e-3;
};
// auto-tune and save result
if(autotune) {
triton::jit::tune_res_t best = jit->autotune("conv", src.c_str(), benchmark);
jit->add_module("conv", src.c_str(), best.params);
}
else {
jit->add_module("conv", src.c_str(), configuration->default_params());
}
triton::driver::kernel* kernel = jit->get_function("conv");
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
}
else
jit = m_jit.at(key).get();

// Run
triton::driver::kernel* kernel = jit->get_function("conv");
triton::jit::launch_information info = jit->get_launch_info("conv");
unsigned GZ = jit->get_int("GZ");
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
// launch info
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
configuration->set_arg(kernel, &a, &b, &c, bias);
stream->enqueue(kernel, grid, {nthreads, 1, 1});
// enqueue
configuration->enqueue(stream, kernel, &a, &b, &c, bias, TM, TN, GZ, nthreads);
return torchc;
}

@@ -1,11 +1,14 @@
import torch
import triton

x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
bias = torch.autograd.Variable(torch.randn(64).cuda(), requires_grad=True)
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
torch.manual_seed(0)
torch.set_printoptions(precision=4)

x = torch.autograd.Variable(torch.randn(64, 3, 8, 8).cuda(), requires_grad=True)
bias = torch.autograd.Variable(torch.randn(6).cuda(), requires_grad=True)
w = torch.autograd.Variable(torch.randn(3, 3, 3, 6).cuda(), requires_grad=True)
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
y_target = torch.autograd.Variable(torch.randn(16, 64, 6, 6).cuda(), requires_grad=True)
y_target = torch.autograd.Variable(torch.randn(64, 6, 8, 8).cuda(), requires_grad=True)

def run(x, w, conv):
    y = conv(x, w)
@@ -13,13 +16,14 @@ def run(x, w, conv):
    loss.backward()
    return loss, y.clone(), x.grad.clone(), w.grad.clone(), bias.grad.clone()

ttyloss, tty, ttdx, ttdw, ttbias = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, bias, (1,1), (0,0)))
ttyloss, tty, ttdx, ttdw, ttbias = run(x, w, lambda x, w: triton.ConvFunction.apply(x, w, bias, (1,1), (1,1)))
x.grad.zero_()
w.grad.zero_()
bias.grad.zero_()
culoss, cuy, cudx, cudw, cubias = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, bias=bias, stride=1, padding=0))
culoss, cuy, cudx, cudw, cubias = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, bias=bias, stride=1, padding=1))

print(ttdx[0,0,:,:], cudx[0,0,:,:])
print(ttdx[0,0,:,:])
print(cudx[0,0,:,:])
print((tty - cuy).norm(2))
print((ttdx - cudx).norm(2))
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
@@ -2,7 +2,7 @@ import torch
from torch.nn.modules.utils import _single, _pair, _triple
import math

torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
torch.ops.load_library("/home/philippe/development/triton/build/examples/python/pytorch/libtorch_triton.so")

class ConvFunction(torch.autograd.Function):
@@ -37,7 +37,7 @@ class _ConvNd(torch.nn.Module):
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        # not everything is supported by Triton
        assert all(x==1 for x in stride)
        assert all(x==1 or x==2 for x in stride)
        assert all(x==1 for x in dilation)
        assert transposed == False
        assert all(x==0 for x in output_padding)
@@ -46,6 +46,7 @@ class _ConvNd(torch.nn.Module):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.weight = torch.nn.Parameter(torch.Tensor(
            in_channels, kernel_size[0], kernel_size[1], out_channels))
@@ -56,7 +57,7 @@ class _ConvNd(torch.nn.Module):
        self.reset_parameters()

    def forward(self, input):
        return ConvFunction.apply(input, self.weight, self.bias, self.padding)
        return ConvFunction.apply(input, self.weight, self.bias, self.stride, self.padding)

    def reset_parameters(self):
        n = self.in_channels