[dnn] added Triton-C derivative computations in conv
@@ -10,7 +10,7 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
// initialization
int32_t B = 4, NF = 32;
int32_t D = 1, H = 24, W = 240;
@@ -77,7 +77,7 @@ int main() {
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
}
}
std::cout << "Pass!" << std::endl;
}
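The test above flags any element whose relative error exceeds 1e-4. A minimal sketch of that check as a standalone helper (name hypothetical, not part of the commit; the denominator assumes positive values, as produced by the uniform initialization used here):

#include <algorithm>
#include <cmath>

// Mirrors the relative-error test in the loop above.
bool close_enough(float host, float ref, float tol = 1e-4f) {
    return std::abs(host - ref) / std::max(host, ref) <= tol;
}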
@@ -4,170 +4,69 @@
#include <vector>
#include "triton/jit.h"
#include "triton/driver/stream.h"
#include "triton/dnn/conv.h"

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

const char* src =
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {8};

__constant__ int32* delta = alloc_const int32[18];
__constant__ int32* masks = alloc_const int32[1024];

void conv(read_only restrict fp32 *a,
read_only restrict fp32 *b,
fp32 *c,
int32 M, int32 N, int32 K,
int32 AN, int32 AH, int32 AW,
int32 CN, int32 CK, int32 CP, int32 CQ,
int32 AC, int32 AR, int32 AS,
int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w,
int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q,
int32 pad_h, int32 pad_w,
int32 bound){
int32 rxa[TM] = get_global_range[TM](0);
int32 rb0[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
int32 rb1[TK] = 0 ... TK;
fp32 C[TM, TN] = 0;
int32 ranh[TM] = rxa / CQ;
int32 raw[TM] = rxa % CQ - pad_w;
int32 ran[TM] = ranh / CP;
int32 rah[TM] = ranh % CP - pad_h;
int32 ra0[TM] = ran*lda_n + rah*lda_h + raw*lda_w;
int32 racr[TK] = rka / AS;
int32 ras[TK] = rka % AS;
int32 rac[TK] = racr / AR;
int32 rar[TK] = racr % AR;
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
fp32* pb[TN, TK] = b + rb1[newaxis, :]*CK + rb0[:, newaxis];
__constant__ int32* pincd[TK] = delta + rka;
__constant__ int32* pd[TK] = delta + AR*AS + rka;
int32 d[TK] = *pd;
int32 incd[TK] = *pincd;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
__constant__ int32* pm[TM] = masks + AR*AS + maskw*AR*AS + maskh*AR*AS*(2*pad_w + 1);
__constant__ int32* pincm[TM] = delta;
int32 incm[TM] = *pincm;
int32 checka0[TM] = *pm;
int32 checka1[TK] = 1 << rka;
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
fp32 a[TM, TK] = checka ? *pa : 0;
fp32 b[TN, TK] = *pb;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, trans(b), C);
pb = pb + TK*CK;
pa = pa + d[newaxis, :];
b = *pb;
pd = pd + incd;
pincd = pincd + incd;
d = *pd;
incd = *pincd;
pm = pm + incm;
pincm = pincm + incm;
incm = *pincm;
checka0 = *pm;
checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
a = checka ? *pa : 0;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 rc1[TN] = get_global_range[TN](1);
int32 rcn[TM] = rxc / (CP*CQ);
int32 rcpq[TM] = rxc % (CP*CQ);
int32 rc0[TM] = rcn * ldc_n + rcpq;
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = rc1 < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = C;
})";

void build_conv_lut(int TK,
int stride_d, int stride_h, int stride_w, int stride_c,
int pad_d, int pad_h, int pad_w,
int T, int R, int S,
std::vector<int>& res, std::vector<int>& masks) {
/* convolution parameters */
int F = T * R * S;
int Nlut = (TK + F - 1) / F * F;
int upsample_w = 1;
int upsample_h = 1;
int upsample_d = 1;
/* unpack index wrt filters */
auto unpack = [&](int32_t trs){
int32_t tr = trs / S;
int32_t s = trs - tr*S;
int32_t t = tr / R;
int32_t r = tr - t*R;
return std::make_tuple(t, r, s);
};
/* increments */
for(size_t i = 0; i < Nlut; ++i)
res[i] = (((i + TK) % Nlut) - i);
/* deltas */
size_t Ds0 = Nlut;
size_t Ds1 = upsample_w;
size_t Ds2 = upsample_h;
size_t Ds3 = upsample_d;
for(size_t pd = 0; pd < Ds3; ++pd)
for(size_t ph = 0; ph < Ds2; ++ph)
for(size_t pw = 0; pw < Ds1; ++pw){
int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
// cumulative increments
for(size_t i = 0; i < Ds0; ++i){
int32_t ctrs = i;
int32_t c = ctrs / F;
int32_t t, r, s;
std::tie(t, r, s) = unpack(ctrs % F);
// next indices
int32_t nextctrs = ctrs + TK;
int32_t nextc = nextctrs / F;
int32_t nextt, nextr, nexts;
std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
// diffs
int32_t cdiff = nextc - c;
int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
// delta pointers
deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
}
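Each entry of the delta table is the pointer step that carries an image pointer from the unpacked (c, t, r, s) position of index i to that of index i + TK along the flattened C*T*R*S reduction axis. A self-contained sketch of what one entry encodes, not part of the commit (identifiers hypothetical; upsampling omitted, as it is fixed to 1 above):

#include <tuple>

// One delta-table entry: pointer step from flat index idx to idx + TK.
int delta_entry(int idx, int TK, int R, int S, int F,
                int stride_c, int stride_d, int stride_h, int stride_w) {
    auto unpack = [&](int trs) {
        int tr = trs / S, s = trs % S;
        return std::make_tuple(tr / R, tr % R, s);  // (t, r, s)
    };
    int c = idx / F, nc = (idx + TK) / F;
    auto [t, r, s]    = unpack(idx % F);
    auto [nt, nr, ns] = unpack((idx + TK) % F);
    return (nc - c)*stride_c + (nt - t)*stride_d
         + (nr - r)*stride_h + (ns - s)*stride_w;
}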
torch::Tensor conv_common(
int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
int32_t T, int32_t R, int32_t S, int32_t NF,
int32_t stride_d, int32_t stride_h, int32_t stride_w,
int32_t pad_d, int32_t pad_h, int32_t pad_w,
triton::dnn::conv::type ty,
torch::Tensor torcha, torch::Tensor torchb
) {
// Configuration
triton::dnn::conv configuration(B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty);
// Allocate output
std::vector<int32_t> c_shapes = configuration.c_shapes();
torch::Tensor torchc;
if(ty == triton::dnn::conv::WGRAD)
torchc = torch::empty({c_shapes[0], c_shapes[2], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
else
torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
// Wrap CUDA handles
c10::DeviceIndex device = torchc.storage().device().index();
triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false);
triton::driver::stream* stream = &sstream;
triton::driver::context* ctx = stream->context();
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
stream->synchronize();
// Create JIT
triton::jit jit(ctx);
std::string src = configuration.src();
jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
// initialize constant memory
if(ty != triton::dnn::conv::WGRAD){
std::vector<int> h_delta;
std::vector<int> h_masks;
configuration.build_deltas(h_delta);
configuration.build_masks(h_masks);
triton::driver::buffer* delta = jit.get_buffer("delta");
triton::driver::buffer* masks = jit.get_buffer("masks");
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
}

/* Masks */
size_t Ms0 = Nlut;
size_t Ms1 = 2*pad_w + 1;
size_t Ms2 = 2*pad_h + 1;
size_t Ms3 = 2*pad_d + 1;

for(size_t pd = 0; pd < Ms3; ++pd)
for(size_t ph = 0; ph < Ms2; ++ph)
for(size_t pw = 0; pw < Ms1; ++pw){
int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
for(size_t i = 0; i < Ms0; ++i){
int32_t t, r, s;
int32_t mask = 0x0;
for(size_t j = 0; j < TK; ++j){
std::tie(t, r, s) = unpack((i + j) % F);
bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
}
masks_ptr[i] = mask;
}
}
for(size_t i = 0; i < Nlut; ++i)
masks[i] = 0x0;
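Each mask word packs TK in-bounds bits: bit j tells lane j of the TK tile whether its load falls inside the image, and the kernel extracts it with checka1 = 1 << rka. A trivial sketch of the per-lane test (hypothetical helper, not part of the commit):

#include <cstdint>

// Bit j of a mask word gates the j-th lane's load, exactly like
// (checka0 & (1 << rka)) > 0 in the kernel source.
bool lane_in_bounds(int32_t mask_word, int lane) {
    return (mask_word >> lane) & 1;
}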
// launch info
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
configuration.set_arg(kernel, &a, &b, &c);
stream->synchronize();
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
return torchc;
}
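A hypothetical call site for the shared driver above, mirroring what conv_fprop below does for NCHW data with 3x3 filters and padding 1 (all values illustrative, not part of the commit):

// Given torch::Tensor data (B, C, H, W) and weight (C, R, S, NF),
// as in the Python test further down:
torch::Tensor out = conv_common(/*B=*/16, /*C=*/64, /*D=*/1, /*H=*/8, /*W=*/8,
                                /*T=*/1, /*R=*/3, /*S=*/3, /*NF=*/64,
                                /*stride d,h,w=*/1, 1, 1,
                                /*pad d,h,w=*/0, 1, 1,
                                triton::dnn::conv::FPROP, data, weight);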

torch::Tensor conv_forward(
torch::Tensor conv_fprop(
const torch::Tensor data,
const torch::Tensor weight) {
// Check
@@ -176,6 +75,7 @@ torch::Tensor conv_forward(
// Unpack data shapes
const int32_t B = data.size(0);
const int32_t Ci = data.size(1);
const int32_t D = 1;
const int32_t H = data.size(2);
const int32_t W = data.size(3);
// Unpack weight shapes
@@ -184,109 +84,76 @@ torch::Tensor conv_forward(
const int32_t R = weight.size(1);
const int32_t S = weight.size(2);
const int32_t NF = weight.size(3);
// Conv parameters
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_h = 1, stride_w = 1;
// Output shapes
int32_t P = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h;
int32_t Q = (W*upsample_w - S + 1 + 2*pad_w + stride_w - 1)/stride_w;
// Allocate output
// Configuration
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
// Check
AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
torch::Tensor output = torch::empty({B, NF, P, Q}, torch::kFloat).cuda();
// Wrap CUDA handles
c10::DeviceIndex device = output.storage().device().index();
triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false);
triton::driver::stream* stream = &sstream;
triton::driver::context* ctx = stream->context();
triton::driver::cu_buffer d(ctx, (CUdeviceptr)data.storage().data(), false);
triton::driver::cu_buffer w(ctx, (CUdeviceptr)weight.storage().data(), false);
triton::driver::cu_buffer a(ctx, (CUdeviceptr)output.storage().data(), false);
// Create JIT
triton::jit jit(ctx);
std::vector<unsigned> params = {
16, 2, 64,
32, 2, 64,
16, 8, 2, 2,
8, 1, 8,
4
};
jit.add_module("conv", src, params);
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned TK = jit.get_int("TK");
// initialize constant memory
int FS = T*R*S;
int nlut = (TK + FS - 1) / FS * FS;
std::vector<int> h_delta(nlut + upsample_d*upsample_h*upsample_w*nlut);
std::vector<int> h_masks(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
// memory stride for images
int32_t stride_i_w = 1;
int32_t stride_i_h = W*stride_i_w;
int32_t stride_i_d = H*stride_i_h;
int32_t stride_i_c = 1*stride_i_d;
int32_t stride_i_n = Ci*stride_i_c;
// memory stride for activations
int32_t stride_o_q = 1;
int32_t stride_o_p = Q*stride_o_q;
int32_t stride_o_m = P*stride_o_p;
int32_t stride_o_k = 1*stride_o_m;
int32_t stride_o_n = NF*stride_o_k;
build_conv_lut(TK, stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
// equivalent matmul dimensions
int32_t M = B*P*Q;
int32_t N = NF;
int32_t K = Ci*R*S;
triton::driver::buffer* delta = jit.get_buffer("delta");
triton::driver::buffer* masks = jit.get_buffer("masks");
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
// launch info
unsigned nthreads = info.num_threads;
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
// fast bounds-checking
unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
unsigned lastk = TK - 1;
bool AT = false;
bool BT = true;
unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
// set arguments
kernel->setArg(0, *d.cu());
kernel->setArg(1, *w.cu());
kernel->setArg(2, *a.cu());
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, B);
kernel->setArg(7, H);
kernel->setArg(8, W);
kernel->setArg(9, NF);
kernel->setArg(10, P);
kernel->setArg(11, Q);
kernel->setArg(12, Ci);
kernel->setArg(13, R);
kernel->setArg(14, S);
kernel->setArg(15, stride_i_n);
kernel->setArg(16, stride_i_c);
kernel->setArg(17, stride_i_h);
kernel->setArg(18, stride_i_w);
kernel->setArg(19, stride_o_n);
kernel->setArg(20, stride_o_k);
kernel->setArg(21, stride_o_p);
kernel->setArg(22, stride_o_q);
kernel->setArg(23, pad_h);
kernel->setArg(24, pad_w);
kernel->setArg(25, bound);
// dry run
stream->enqueue(kernel, grid, {nthreads, 1, 1});
return output;
return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
}
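With the shapes used in the Python test below (B=16, Ci=64, H=W=8, 3x3 filters, NF=64, padding 1, unit strides), the implicit-GEMM dimensions above work out as follows (a worked example, not part of the commit):

// FPROP as an implicit GEMM: one output pixel per row, one filter per column.
constexpr int B = 16, Ci = 64, H = 8, W = 8, R = 3, S = 3, NF = 64, pad = 1;
constexpr int P = H - R + 1 + 2*pad;  // = 8 output rows
constexpr int Q = W - S + 1 + 2*pad;  // = 8 output columns
constexpr int M = B * P * Q;          // = 1024
constexpr int N = NF;                 // = 64
constexpr int K = Ci * R * S;         // = 576 reduction length
static_assert(M == 1024 && N == 64 && K == 576, "matmul-equivalent dims");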

torch::Tensor conv_bprop(
const torch::Tensor derror,
const torch::Tensor weight){
// Check
CHECK_INPUT(derror);
CHECK_INPUT(weight);
// Unpack data shapes
const int32_t B = derror.size(0);
const int32_t Ki = derror.size(1);
const int32_t M = 1;
const int32_t P = derror.size(2);
const int32_t Q = derror.size(3);
// Unpack weight shapes
const int32_t C = weight.size(0);
const int32_t T = 1;
const int32_t R = weight.size(1);
const int32_t S = weight.size(2);
const int32_t Kw = weight.size(3);
// Compute D, H, W
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
const int32_t D = M*stride_d + T - 1 - 2*pad_d + stride_d - 1 / upsample_d;
const int32_t H = P*stride_h + R - 1 - 2*pad_h + stride_h - 1 / upsample_h;
const int32_t W = Q*stride_w + S - 1 - 2*pad_w + stride_w - 1 / upsample_w;
// Check
AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
}
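BPROP recovers the input spatial size from the error's: with unit strides and upsampling the formula above reduces to H = P + R - 1 - 2*pad_h. A quick check against the Python test shapes (a worked example, not part of the commit):

// Inverting the forward shape formula for the 8x8 / 3x3 / pad-1 case.
constexpr int P = 8, R = 3, pad_h = 1, stride_h = 1;
constexpr int H = P*stride_h + R - 1 - 2*pad_h + stride_h - 1;  // = 8
static_assert(H == 8, "bprop recovers the original input height");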

torch::Tensor conv_wgrad(
const torch::Tensor data,
const torch::Tensor derror
){
// Check
CHECK_INPUT(data);
CHECK_INPUT(derror);
// Unpack data shapes
const int32_t Ba = data.size(0);
const int32_t C = data.size(1);
const int32_t D = 1;
const int32_t H = data.size(2);
const int32_t W = data.size(3);
// Unpack error shapes
const int32_t Bb = derror.size(0);
const int32_t K = derror.size(1);
const int32_t M = 1;
const int32_t P = derror.size(2);
const int32_t Q = derror.size(3);
// Compute T, R, S
const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
const int32_t T = (D - M*stride_d + 1 + 2*pad_d - stride_d + 1)*upsample_d;
const int32_t R = (H - P*stride_h + 1 + 2*pad_h - stride_h + 1)*upsample_h;
const int32_t S = (W - Q*stride_w + 1 + 2*pad_w - stride_w + 1)*upsample_w;
// Check
AT_CHECK(Ba == Bb, "Batch sizes of data and error must match");
return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);
}

static auto registry =
torch::jit::RegisterOperators("triton::conv_forward", &conv_forward);
torch::jit::RegisterOperators("triton::conv_fprop", &conv_fprop)
.op("triton::conv_bprop", &conv_bprop)
.op("triton::conv_wgrad", &conv_wgrad);

@@ -1,9 +1,48 @@
import torch
from torch.autograd import Variable
torch.manual_seed(0)

class TritonConv(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        output = torch.ops.triton.conv_fprop(input, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = None
        if ctx.needs_input_grad[0]:
            grad_input = torch.ops.triton.conv_bprop(grad_output.contiguous(), weight)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.ops.triton.conv_wgrad(input, grad_output.contiguous())
        return grad_input, grad_weight


torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")

d = torch.empty(64, 64, 64, 64).uniform_(0, 1).cuda()
w = torch.empty(64, 3, 3, 64).uniform_(0, 1).cuda()
a = torch.ops.triton.conv_forward(d, w)
print(a)
x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
y_target = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)

def run(x, w, conv):
    y = conv(x, w)
    loss = (y - y_target).norm(2)
    loss.backward()
    return loss, y.clone(), x.grad.clone(), w.grad.clone()

ttyloss, tty, ttdx, ttdw = run(x, w, TritonConv.apply)
x.grad.zero_()
w.grad.zero_()
culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1))

print((tty - cuy).norm(2))
print((ttdx - cudx).norm(2))
print((ttdw.permute(3,0,1,2) - cudw).norm(2))
#print(ttdx)
#print(cudx)
#print(ttdw)
#print(cudw)
#print((ttdw.permute(3,0,1,2) - cudw).norm(2))

@@ -20,12 +20,12 @@ public:
conv(int B, int NC,
int D, int H, int W,
int T, int R, int S, int NF,
int upsample_d, int upsample_h, int upsample_w,
int stride_d, int stride_h, int stride_w,
int pad_d, int pad_h, int pad_w,
type ty = FPROP)
: NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
stride_d_(1), stride_h_(1), stride_w_(1),
stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
upsample_d_(1), upsample_h_(1), upsample_w_(1),
pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
ty_(ty)
{
@@ -93,6 +93,10 @@ public:
1, std::multiplies<int>());
}

std::vector<int32_t> c_shapes() {
return shapes_c_;
}

void build_deltas(std::vector<int>& deltas){
if(ty_ == WGRAD)
throw std::runtime_error("no look-up table necessary for wgrad");
@@ -120,6 +124,7 @@ public:
int32_t c = ctrs / Fs_;
int32_t t, r, s;
std::tie(t, r, s) = unpack(ctrs % Fs_);

// next indices
int32_t nextctrs = ctrs + TK_;
int32_t nextc = nextctrs / Fs_;
@@ -223,6 +228,43 @@ public:

std::string xprop() {

std::string declare_pb;
if(ty_ == FPROP){
declare_pb = R"(
fp32* pb[TN, TK] = b + rkb[newaxis, :]*ldb_s + rb0[:, newaxis];
)";
}
else{
declare_pb = R"(
fp32* pb_base[TN, TK] = b + rb0[:, newaxis]*ldb_c;
int32 rbk[TK] = rkb / (BH*BW);
int32 rbrs[TK] = rkb % (BH*BW);
int32 rbs[TK] = BW - 1 - rbrs % BW;
int32 rbr[TK] = BH - 1 - rbrs / BW;
int32 rb1[TK] = rbk*ldb_k + rbr*ldb_r + rbs*ldb_s;
fp32* pb[TN, TK] = pb_base + rb1[newaxis, :];
)";
}
std::string increment_pb;
if(ty_ == FPROP){
increment_pb = R"(
pb = pb + TK*ldb_s;
)";
}
else{
increment_pb = R"(
rbrs = rbrs + TK;
rkb = rkb + TK;
rbk = rkb / (BH*BW);
rbrs = rkb % (BH*BW);
rbs = BW - 1 - rbrs % BW;
rbr = BH - 1 - rbrs / BW;
rb1 = rbk*ldb_k + rbr*ldb_r + rbs*ldb_s;
pb = pb_base + rb1[newaxis, :];
)";
}
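In the BPROP variant the filter is traversed rotated by 180 degrees in (r, s), which is what the BW - 1 - ... and BH - 1 - ... flips above implement. A hedged C++ sketch of that index map, not part of the commit (helpers hypothetical):

// 180-degree filter rotation used when convolving the error with W:
// lane k's (r, s) offsets are mirrored inside the BH x BW filter window.
int rotated_r(int rbrs, int BW, int BH) { return BH - 1 - rbrs / BW; }
int rotated_s(int rbrs, int BW)         { return BW - 1 - rbrs % BW; }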

std::string res =
R"(
const tunable int32 TM = {16, 32, 64};
@@ -246,7 +288,7 @@ public:
int32 rxa[TM] = get_global_range[TM](0);
int32 rb0[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
int32 rb1[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 C[TM, TN] = 0;
int32 rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % CW - pad_w;
@@ -258,8 +300,8 @@ public:
int32 rac[TK] = racr / BH;
int32 rar[TK] = racr % BH;
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
fp32* pb[TN, TK] = b + rb1[newaxis, :]*ldb_s + rb0[:, newaxis];
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)"
+ declare_pb + R"(
__constant__ int32* pincd[TK] = delta + rka;
__constant__ int32* pd[TK] = delta + BH*BW + rka;
int32 d[TK] = *pd;
@@ -276,8 +318,8 @@ public:
fp32 b[TN, TK] = *pb;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, trans(b), C);
pb = pb + TK*ldb_s;
pa = pa + d[newaxis, :];
pa = pa + d[newaxis, :];)"
+ increment_pb + R"(
b = *pb;
pd = pd + incd;
pincd = pincd + incd;
@@ -288,6 +330,7 @@ public:
incm = *pincm;
checka0 = *pm;
checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
checka = checka && (k > TK);
a = checka ? *pa : 0;
}
int32 rxc[TM] = get_global_range[TM](0);
@@ -379,7 +422,7 @@ public:
{
IN_DTYPE acc;
for(int32_t n = 0; n < shapes_c_[0]; ++n)
for(int32_t k = 0; k < shapes_c_[1]; ++k)
for(int32_t cf = 0; cf < shapes_c_[1]; ++cf)
for(int32_t cd = 0; cd < shapes_c_[2]; ++cd)
for(int32_t ch = 0; ch < shapes_c_[3]; ++ch)
for(int32_t cw = 0; cw < shapes_c_[4]; ++cw)
@@ -388,7 +431,7 @@ public:
int32_t d = cd*stride_d_ - pad_d_;
int32_t h = ch*stride_h_ - pad_h_;
int32_t w = cw*stride_w_ - pad_w_;
for(int32_t c = 0; c < shapes_b_[0]; ++c)
for(int32_t ac = 0; ac < shapes_a_[1]; ++ac)
for(int32_t bd = 0; bd < shapes_b_[1]; ++bd)
for(int32_t bh = 0; bh < shapes_b_[2]; ++bh)
for(int32_t bw = 0; bw < shapes_b_[3]; ++bw){
@@ -400,11 +443,19 @@ public:
aw >= 0 && aw < shapes_a_[4]);
IN_DTYPE a = 0;
if(in_bounds)
a = A[n*ld_a_[0] + c*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
IN_DTYPE b = B[c*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + k*ld_b_[4]];
a = A[n*ld_a_[0] + ac*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
IN_DTYPE b;
if(ty_==FPROP)
b = B[ac*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + cf*ld_b_[4]];
else{
int32_t bdd = bd;
int32_t bhh = bh;
int32_t bww = bw;
b = B[cf*ld_b_[0] + bdd*ld_b_[1] + bhh*ld_b_[2] + bww*ld_b_[3] + ac*ld_b_[4]];
}
acc = std::fma(a, b, acc);
}
C[n*ld_c_[0] + k*ld_c_[1] + cd*ld_c_[2] + ch*ld_c_[3] + cw*ld_c_[4]] = acc;
C[n*ld_c_[0] + cf*ld_c_[1] + cd*ld_c_[2] + ch*ld_c_[3] + cw*ld_c_[4]] = acc;
}
}

@@ -193,12 +193,20 @@ public:
static cudnnStatus_t cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc, int pad_h, int pad_w, int u, int v, int upscalex, int upscaley, cudnnConvolutionMode_t mode);
static cudnnStatus_t cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc, int arrayLength, const int padA[], const int filterStrideA[], const int upscaleA[], cudnnConvolutionMode_t mode, cudnnDataType_t dataType);
static cudnnStatus_t cudnnSetPoolingNdDescriptor(cudnnPoolingDescriptor_t poolingDesc, const cudnnPoolingMode_t mode, const cudnnNanPropagation_t maxpoolingNanOpt, int nbDims, const int windowDimA[], const int paddingA[], const int strideA[]);
static cudnnStatus_t cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
static cudnnStatus_t cudnnTransformTensor(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
// pooling
static cudnnStatus_t cudnnPoolingForward(cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
// forward
static cudnnStatus_t cudnnGetConvolutionForwardAlgorithm(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const cudnnFilterDescriptor_t wDesc, const cudnnConvolutionDescriptor_t convDesc, const cudnnTensorDescriptor_t yDesc, cudnnConvolutionFwdPreference_t preference, size_t memoryLimitInBytes, cudnnConvolutionFwdAlgo_t *algo);
static cudnnStatus_t cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const cudnnFilterDescriptor_t wDesc, const cudnnConvolutionDescriptor_t convDesc, const cudnnTensorDescriptor_t yDesc, cudnnConvolutionFwdAlgo_t algo, size_t *sizeInBytes);
static cudnnStatus_t cudnnConvolutionForward(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const cudnnFilterDescriptor_t wDesc, const void *w, const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, void *workSpace, size_t workSpaceSizeInBytes, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
static cudnnStatus_t cudnnPoolingForward(cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
static cudnnStatus_t cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
static cudnnStatus_t cudnnTransformTensor(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
// backward data
static cudnnStatus_t cudnnConvolutionBackwardData(cudnnHandle_t handle, const void *alpha, const cudnnFilterDescriptor_t wDesc, const void *w, const cudnnTensorDescriptor_t dyDesc, const void *dy, const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionBwdDataAlgo_t algo, void* workSpace, size_t workSpaceSizeInBytes, const void* beta, const cudnnTensorDescriptor_t dxDesc, void *dx);
static cudnnStatus_t cudnnGetConvolutionBackwardDataAlgorithm(cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, const cudnnTensorDescriptor_t dyDesc, const cudnnConvolutionDescriptor_t convDesc, const cudnnTensorDescriptor_t dxDesc, cudnnConvolutionBwdDataPreference_t preference, size_t memoryLimitInBytes, cudnnConvolutionBwdDataAlgo_t* algo);
// backward filter
static cudnnStatus_t cudnnConvolutionBackwardFilter(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const cudnnTensorDescriptor_t dyDesc, const void *dy, const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionBwdFilterAlgo_t algo, void* workSpace, size_t workSpaceSizeInBytes, const void* beta, const cudnnFilterDescriptor_t dwDesc, void *dw);
static cudnnStatus_t cudnnGetConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const cudnnTensorDescriptor_t dyDesc, const cudnnConvolutionDescriptor_t convDesc, const cudnnFilterDescriptor_t dwDesc, cudnnConvolutionBwdFilterPreference_t preference, size_t memoryLimitInBytes, cudnnConvolutionBwdFilterAlgo_t* algo);

// SPIR-V libraries
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);