[dnn/conv] Added bias and forward stride
This commit is contained in:
@@ -14,7 +14,7 @@ typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
|
||||
int32_t, int32_t, int32_t, int32_t,
|
||||
int32_t, int32_t, int32_t,
|
||||
int32_t, int32_t, int32_t,
|
||||
triton::dnn::conv::type> conv_key_t;
|
||||
triton::dnn::conv::type, bool> conv_key_t;
|
||||
|
||||
static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
|
||||
static std::map<conv_key_t, std::unique_ptr<triton::jit>> m_jit;
|
||||
@@ -26,7 +26,7 @@ torch::Tensor conv_common(
|
||||
int32_t stride_d, int32_t stride_h, int32_t stride_w,
|
||||
int32_t pad_d, int32_t pad_h, int32_t pad_w,
|
||||
triton::dnn::conv::type ty,
|
||||
torch::Tensor torcha, torch::Tensor torchb
|
||||
torch::Tensor torcha, torch::Tensor torchb, torch::Tensor torchbias
|
||||
) {
|
||||
// Wrap CUDA handles
|
||||
c10::DeviceIndex device = torcha.storage().device().index();
|
||||
@@ -40,13 +40,16 @@ torch::Tensor conv_common(
|
||||
// Get context
|
||||
triton::driver::context* ctx = stream->context();
|
||||
// Get configuration
|
||||
conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty};
|
||||
bool has_bias = torchbias.storage().size() > 0;
|
||||
conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty, has_bias};
|
||||
triton::dnn::conv* configuration;
|
||||
if(m_config.find(key) == m_config.end())
|
||||
configuration = m_config.emplace(key, new triton::dnn::conv(
|
||||
B, C, D, H, W, T, R, S, NF,
|
||||
stride_d, stride_h, stride_w,
|
||||
pad_d, pad_h, pad_w, ty)).first->second.get();
|
||||
pad_d, pad_h, pad_w,
|
||||
1, 1, 1,
|
||||
ty, has_bias)).first->second.get();
|
||||
else
|
||||
configuration = m_config.at(key).get();
|
||||
// Get JIT
|
||||
@@ -55,12 +58,16 @@ torch::Tensor conv_common(
|
||||
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::string src = configuration->src();
|
||||
jit->add_module("conv", src.c_str(), configuration->default_params());
|
||||
triton::driver::kernel* kernel = jit->get_function("conv");
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
else
|
||||
jit = m_jit.at(key).get();
|
||||
// Get memory
|
||||
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
|
||||
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
|
||||
triton::driver::cu_buffer cubias(ctx, (CUdeviceptr)torchbias.storage().data(), false);
|
||||
triton::driver::buffer* bias = has_bias ? &cubias : nullptr;
|
||||
// Allocate output
|
||||
std::vector<int32_t> c_shapes = configuration->c_shapes();
|
||||
torch::Tensor torchc;
|
||||
@@ -76,10 +83,9 @@ torch::Tensor conv_common(
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
// launch info
|
||||
configuration->init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
unsigned nthreads = info.num_threads;
|
||||
std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
|
||||
configuration->set_arg(kernel, &a, &b, &c);
|
||||
configuration->set_arg(kernel, &a, &b, &c, bias);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
return torchc;
|
||||
}
|
||||
@@ -87,6 +93,8 @@ torch::Tensor conv_common(
|
||||
torch::Tensor conv_fprop(
|
||||
const torch::Tensor data,
|
||||
const torch::Tensor weight,
|
||||
const torch::Tensor bias,
|
||||
int64_t stride_h, int64_t stride_w,
|
||||
int64_t pad_h, int64_t pad_w) {
|
||||
// Check
|
||||
CHECK_INPUT(data);
|
||||
@@ -104,16 +112,19 @@ torch::Tensor conv_fprop(
|
||||
const int32_t S = weight.size(2);
|
||||
const int32_t NF = weight.size(3);
|
||||
// Configuration
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t stride_d = 1;
|
||||
const int32_t pad_d = 0;
|
||||
// Check
|
||||
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
|
||||
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
|
||||
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight, bias);
|
||||
}
|
||||
|
||||
torch::Tensor conv_bprop(
|
||||
const torch::Tensor derror,
|
||||
const torch::Tensor weight,
|
||||
const torch::Tensor bias,
|
||||
int64_t H, int64_t W,
|
||||
int64_t stride_h, int64_t stride_w,
|
||||
int64_t pad_h, int64_t pad_w){
|
||||
// Check
|
||||
CHECK_INPUT(derror);
|
||||
@@ -131,22 +142,20 @@ torch::Tensor conv_bprop(
|
||||
const int32_t S = weight.size(2);
|
||||
const int32_t Kw = weight.size(3);
|
||||
// Compute M, P, Q
|
||||
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t stride_d = 1;
|
||||
int32_t pad_d = 0;
|
||||
const int32_t D = (M*stride_d + T - 1 - 2*pad_d - stride_d + 1) / upsample_d;
|
||||
const int32_t H = (P*stride_d + R - 1 - 2*pad_h - stride_h + 1) / upsample_h;
|
||||
const int32_t W = (Q*stride_d + S - 1 - 2*pad_w - stride_w + 1) / upsample_w;
|
||||
|
||||
|
||||
int32_t D = 1;
|
||||
// Check
|
||||
AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
|
||||
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
|
||||
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight, bias);
|
||||
}
|
||||
|
||||
torch::Tensor conv_wgrad(
|
||||
const torch::Tensor data,
|
||||
const torch::Tensor derror,
|
||||
const torch::Tensor bias,
|
||||
int64_t R, int64_t S,
|
||||
int64_t stride_h, int64_t stride_w,
|
||||
int64_t pad_h, int64_t pad_w
|
||||
){
|
||||
// Check
|
||||
@@ -166,16 +175,12 @@ torch::Tensor conv_wgrad(
|
||||
const int32_t Q = derror.size(3);
|
||||
// Compute M, P, Q
|
||||
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
const int32_t stride_d = 1;
|
||||
const int32_t pad_d = 0;
|
||||
const int32_t T = (D - M*stride_d + 1 + 2*pad_d + stride_d - 1)*upsample_d;
|
||||
const int32_t R = (H - P*stride_h + 1 + 2*pad_h + stride_h - 1)*upsample_h;
|
||||
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w + stride_w - 1)*upsample_w;
|
||||
|
||||
|
||||
const int32_t T = 1;
|
||||
// Check
|
||||
AT_CHECK(Ba == Bb, "Number of channels in error and weights must match");
|
||||
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);
|
||||
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror, bias);
|
||||
}
|
||||
|
||||
static auto registry =
|
||||
|
Reference in New Issue
Block a user