[triton/python/conv]: Added cache for compiled kernels
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
#include <torch/script.h>
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
#include <vector>
|
||||
#include "triton/jit.h"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/dnn/conv.h"
|
||||
|
||||
@@ -10,6 +10,16 @@
|
||||
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
|
||||
int32_t, int32_t, int32_t, int32_t,
|
||||
int32_t, int32_t, int32_t,
|
||||
int32_t, int32_t, int32_t,
|
||||
triton::dnn::conv::type> conv_key_t;
|
||||
|
||||
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
|
||||
static std::map<conv_key_t, std::unique_ptr<triton::jit>> m_jit;
|
||||
static std::map<conv_key_t, std::unique_ptr<triton::dnn::conv>> m_config;
|
||||
|
||||
torch::Tensor conv_common(
|
||||
int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
|
||||
int32_t T, int32_t R, int32_t S, int32_t NF,
|
||||
@@ -18,41 +28,59 @@ torch::Tensor conv_common(
|
||||
triton::dnn::conv::type ty,
|
||||
torch::Tensor torcha, torch::Tensor torchb
|
||||
) {
|
||||
// Configuration
|
||||
triton::dnn::conv configuration(B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty);
|
||||
// Wrap CUDA handles
|
||||
c10::DeviceIndex device = torcha.storage().device().index();
|
||||
// Get stream
|
||||
CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
|
||||
triton::driver::stream* stream;
|
||||
if(m_stream.find(custream) == m_stream.end())
|
||||
stream = m_stream.emplace(custream, new triton::driver::cu_stream(custream, false)).first->second.get();
|
||||
else
|
||||
stream = m_stream.at(custream).get();
|
||||
// Get context
|
||||
triton::driver::context* ctx = stream->context();
|
||||
// Get configuration
|
||||
conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty};
|
||||
triton::dnn::conv* configuration;
|
||||
if(m_config.find(key) == m_config.end())
|
||||
configuration = m_config.emplace(key, new triton::dnn::conv(
|
||||
B, C, D, H, W, T, R, S, NF,
|
||||
stride_d, stride_h, stride_w,
|
||||
pad_d, pad_h, pad_w, ty)).first->second.get();
|
||||
else
|
||||
configuration = m_config.at(key).get();
|
||||
// Get JIT
|
||||
triton::jit* jit;
|
||||
if(m_jit.find(key) == m_jit.end()){
|
||||
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::string src = configuration->src();
|
||||
jit->add_module("conv", src.c_str(), configuration->default_params());
|
||||
}
|
||||
else
|
||||
jit = m_jit.at(key).get();
|
||||
// Get memory
|
||||
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
|
||||
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
|
||||
// Allocate output
|
||||
std::vector<int32_t> c_shapes = configuration.c_shapes();
|
||||
std::vector<int32_t> c_shapes = configuration->c_shapes();
|
||||
torch::Tensor torchc;
|
||||
if(ty == triton::dnn::conv::WGRAD)
|
||||
torchc = torch::empty({c_shapes[0], c_shapes[2], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
|
||||
else
|
||||
torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
|
||||
// Wrap CUDA handles
|
||||
c10::DeviceIndex device = torchc.storage().device().index();
|
||||
triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false);
|
||||
triton::driver::stream* stream = &sstream;
|
||||
triton::driver::context* ctx = stream->context();
|
||||
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
|
||||
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
|
||||
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
|
||||
stream->synchronize();
|
||||
// Create JIT
|
||||
triton::jit jit(ctx);
|
||||
std::string src = configuration.src();
|
||||
jit.add_module("conv", src.c_str(), configuration.default_params());
|
||||
triton::driver::kernel* kernel = jit.get_function("conv");
|
||||
triton::jit::launch_information info = jit.get_launch_info("conv");
|
||||
// Add module to JIT
|
||||
triton::driver::kernel* kernel = jit->get_function("conv");
|
||||
triton::jit::launch_information info = jit->get_launch_info("conv");
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
// launch info
|
||||
configuration.init(stream, jit);
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
unsigned nthreads = info.num_threads;
|
||||
std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
|
||||
configuration.set_arg(kernel, &a, &b, &c);
|
||||
stream->synchronize();
|
||||
std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
|
||||
configuration->set_arg(kernel, &a, &b, &c);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream->synchronize();
|
||||
return torchc;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user