From 5941501f70a00346f7065f54217ee97c980c423d Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Mon, 13 May 2019 00:38:26 -0400
Subject: [PATCH] [dnn] added Triton-C derivative computations in conv

---
 examples/cpp/conv.cpp            |   4 +-
 examples/python/pytorch/conv.cpp | 383 ++++++++++---------------------
 examples/python/pytorch/main.py  |  49 +++-
 include/triton/dnn/conv.h        |  77 +++++--
 include/triton/driver/dispatch.h |  14 +-
 5 files changed, 246 insertions(+), 281 deletions(-)

diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp
index 8cddeb588..5d4d20c7d 100644
--- a/examples/cpp/conv.cpp
+++ b/examples/cpp/conv.cpp
@@ -10,7 +10,7 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
   triton::jit jit(context);
-  triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
+  triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
   // initialization
   int32_t B = 4, NF = 32;
   int32_t D = 1, H = 24, W = 240;
@@ -77,7 +77,7 @@ int main() {
     if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
       std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
       exit(EXIT_FAILURE);
-    }
+    }
   }
   std::cout << "Pass!" << std::endl;
 }
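For orientation, the three triton::dnn::conv modes exercised by this patch correspond to the following PyTorch operations. This is a reference sketch only, not part of the patch; shapes loosely follow this test, and note that torch uses KCRS weight layout while the Triton examples use CRSK:

    import torch
    import torch.nn.functional as F

    x  = torch.randn(4, 32, 24, 240)          # NCHW data
    w  = torch.randn(32, 32, 3, 3)            # KCRS weights (torch layout)
    y  = F.conv2d(x, w, padding=1)            # FPROP: activations
    dy = torch.randn_like(y)
    # BPROP: gradient w.r.t. the input, from dy and the weights
    dx = torch.nn.grad.conv2d_input(x.shape, w, dy, padding=1)
    # WGRAD: gradient w.r.t. the weights, from the input and dy
    dw = torch.nn.grad.conv2d_weight(x, w.shape, dy, padding=1)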
diff --git a/examples/python/pytorch/conv.cpp b/examples/python/pytorch/conv.cpp
index 2aa46175d..09a3f6eaa 100644
--- a/examples/python/pytorch/conv.cpp
+++ b/examples/python/pytorch/conv.cpp
@@ -4,170 +4,69 @@
 #include
 #include "triton/jit.h"
 #include "triton/driver/stream.h"
+#include "triton/dnn/conv.h"
 
 #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
-const char* src =
-R"(
-const tunable int32 TM = {16, 32, 64};
-const tunable int32 TN = {16, 32, 64};
-const tunable int32 TK = {8};
-
-__constant__ int32* delta = alloc_const int32[18];
-__constant__ int32* masks = alloc_const int32[1024];
-
-void conv(read_only restrict fp32 *a,
-          read_only restrict fp32 *b,
-          fp32 *c,
-          int32 M, int32 N, int32 K,
-          int32 AN, int32 AH, int32 AW,
-          int32 CN, int32 CK, int32 CP, int32 CQ,
-          int32 AC, int32 AR, int32 AS,
-          int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w,
-          int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q,
-          int32 pad_h, int32 pad_w,
-          int32 bound){
-  int32 rxa[TM] = get_global_range[TM](0);
-  int32 rb0[TN] = get_global_range[TN](1);
-  int32 rka[TK] = 0 ... TK;
-  int32 rb1[TK] = 0 ... TK;
-  fp32 C[TM, TN] = 0;
-  int32 ranh[TM] = rxa / CQ;
-  int32 raw[TM] = rxa % CQ - pad_w;
-  int32 ran[TM] = ranh / CP;
-  int32 rah[TM] = ranh % CP - pad_h;
-  int32 ra0[TM] = ran*lda_n + rah*lda_h + raw*lda_w;
-  int32 racr[TK] = rka / AS;
-  int32 ras[TK] = rka % AS;
-  int32 rac[TK] = racr / AR;
-  int32 rar[TK] = racr % AR;
-  int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
-  fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
-  fp32* pb[TN, TK] = b + rb1[newaxis, :]*CK + rb0[:, newaxis];
-  __constant__ int32* pincd[TK] = delta + rka;
-  __constant__ int32* pd[TK] = delta + AR*AS + rka;
-  int32 d[TK] = *pd;
-  int32 incd[TK] = *pincd;
-  int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
-  int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
-  __constant__ int32* pm[TM] = masks + AR*AS + maskw*AR*AS + maskh*AR*AS*(2*pad_w + 1);
-  __constant__ int32* pincm[TM] = delta;
-  int32 incm[TM] = *pincm;
-  int32 checka0[TM] = *pm;
-  int32 checka1[TK] = 1 << rka;
-  int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
-  fp32 a[TM, TK] = checka ? *pa : 0;
-  fp32 b[TN, TK] = *pb;
-  for(int32 k = K; k > 0; k = k - TK){
-    C = dot(a, trans(b), C);
-    pb = pb + TK*CK;
-    pa = pa + d[newaxis, :];
-    b = *pb;
-    pd = pd + incd;
-    pincd = pincd + incd;
-    d = *pd;
-    incd = *pincd;
-    pm = pm + incm;
-    pincm = pincm + incm;
-    incm = *pincm;
-    checka0 = *pm;
-    checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
-    a = checka ? *pa : 0;
-  }
-  int32 rxc[TM] = get_global_range[TM](0);
-  int32 rc1[TN] = get_global_range[TN](1);
-  int32 rcn[TM] = rxc / (CP*CQ);
-  int32 rcpq[TM] = rxc % (CP*CQ);
-  int32 rc0[TM] = rcn * ldc_n + rcpq;
-  fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
-  int1 checkc0[TM] = rxc < M;
-  int1 checkc1[TN] = rc1 < N;
-  int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-  @checkc *pc = C;
-})";
-
-void build_conv_lut(int TK,
-                    int stride_d, int stride_h, int stride_w, int stride_c,
-                    int pad_d, int pad_h, int pad_w,
-                    int T, int R, int S,
-                    std::vector<int32_t>& res, std::vector<int32_t>& masks) {
-  /* convolution parameters */
-  int F = T * R * S;
-  int Nlut = (TK + F - 1) / F * F;
-  int upsample_w = 1;
-  int upsample_h = 1;
-  int upsample_d = 1;
-  /* unpack index wrt filters */
-  auto unpack = [&](int32_t trs){
-    int32_t tr = trs / S;
-    int32_t s = trs - tr*S;
-    int32_t t = tr / R;
-    int32_t r = tr - t*R;
-    return std::make_tuple(t, r, s);
-  };
-  /* increments */
-  for(size_t i = 0; i < Nlut; ++i)
-    res[i] = (((i + TK) % Nlut) - i);
-  /* deltas */
-  size_t Ds0 = Nlut;
-  size_t Ds1 = upsample_w;
-  size_t Ds2 = upsample_h;
-  size_t Ds3 = upsample_d;
-  for(size_t pd = 0; pd < Ds3; ++pd)
-  for(size_t ph = 0; ph < Ds2; ++ph)
-  for(size_t pw = 0; pw < Ds1; ++pw){
-    int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
-    // cumulative increments
-    for(size_t i = 0; i < Ds0; ++i){
-      int32_t ctrs = i;
-      int32_t c = ctrs / F;
-      int32_t t, r, s;
-      std::tie(t, r, s) = unpack(ctrs % F);
-      // next indices
-      int32_t nextctrs = ctrs + TK;
-      int32_t nextc = nextctrs / F;
-      int32_t nextt, nextr, nexts;
-      std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
-      // diffs
-      int32_t cdiff = nextc - c;
-      int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
-      int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
-      int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
-      // delta pointers
-      deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
-    }
-  }
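The delta table built by the deleted code above survives as triton::dnn::conv::build_deltas (see the include/triton/dnn/conv.h hunk below): it caches, for each position in the K loop, how far the image pointer must jump when the flattened (c, r, s) filter coordinate advances by TK. The same computation restricted to 2D and unit upsampling, as a Python sketch with illustrative names:

    def build_deltas(TK, R, S, stride_c, stride_h, stride_w):
        # For each position i in the filter cycle, the image-pointer jump
        # when the flattened (c, r, s) index advances by TK.
        F = R * S
        nlut = (TK + F - 1) // F * F
        unpack = lambda crs: (crs // F, (crs % F) // S, (crs % F) % S)
        deltas = []
        for i in range(nlut):
            c, r, s = unpack(i)
            nc, nr, ns = unpack(i + TK)
            deltas.append((nc - c)*stride_c + (nr - r)*stride_h + (ns - s)*stride_w)
        return deltas

    # e.g. a 3x3 filter swept with TK=8 over a row-major CHW image
    print(build_deltas(8, 3, 3, stride_c=24*240, stride_h=240, stride_w=1))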
+torch::Tensor conv_common(
+    int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
+    int32_t T, int32_t R, int32_t S, int32_t NF,
+    int32_t stride_d, int32_t stride_h, int32_t stride_w,
+    int32_t pad_d, int32_t pad_h, int32_t pad_w,
+    triton::dnn::conv::type ty,
+    torch::Tensor torcha, torch::Tensor torchb
+    ) {
+  // Configuration
+  triton::dnn::conv configuration(B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty);
+  // Allocate output
+  std::vector<int32_t> c_shapes = configuration.c_shapes();
+  torch::Tensor torchc;
+  if(ty == triton::dnn::conv::WGRAD)
+    torchc = torch::empty({c_shapes[0], c_shapes[2], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
+  else
+    torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
+  // Wrap CUDA handles
+  c10::DeviceIndex device = torchc.storage().device().index();
+  triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false);
+  triton::driver::stream* stream = &sstream;
+  triton::driver::context* ctx = stream->context();
+  triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
+  triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
+  triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
+  stream->synchronize();
+  // Create JIT
+  triton::jit jit(ctx);
+  std::string src = configuration.src();
+  jit.add_module("conv", src.c_str(), configuration.default_params());
+  triton::driver::kernel* kernel = jit.get_function("conv");
+  triton::jit::launch_information info = jit.get_launch_info("conv");
+  // launch info
+  unsigned TM = info.global_range_size[0];
+  unsigned TN = info.global_range_size[1];
+  // initialize constant memory
+  if(ty != triton::dnn::conv::WGRAD){
+    std::vector<int32_t> h_delta;
+    std::vector<int32_t> h_masks;
+    configuration.build_deltas(h_delta);
+    configuration.build_masks(h_masks);
+    triton::driver::buffer* delta = jit.get_buffer("delta");
+    triton::driver::buffer* masks = jit.get_buffer("masks");
+    stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
+    stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
+  }
-
-  /* Masks */
-  size_t Ms0 = Nlut;
-  size_t Ms1 = 2*pad_w + 1;
-  size_t Ms2 = 2*pad_h + 1;
-  size_t Ms3 = 2*pad_d + 1;
-
-  for(size_t pd = 0; pd < Ms3; ++pd)
-  for(size_t ph = 0; ph < Ms2; ++ph)
-  for(size_t pw = 0; pw < Ms1; ++pw){
-    int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
-    for(size_t i = 0; i < Ms0; ++i){
-      int32_t t, r, s;
-      int32_t mask = 0x0;
-      for(size_t j = 0; j < TK; ++j){
-        std::tie(t, r, s) = unpack((i + j) % F);
-        bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
-        bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
-        bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
-        mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
-      }
-      masks_ptr[i] = mask;
-    }
-  }
-  for(size_t i = 0; i < Nlut; ++i)
-    masks[i] = 0x0;
+  // launch info
+  unsigned nthreads = info.num_threads;
+  std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
+  configuration.set_arg(kernel, &a, &b, &c);
+  stream->synchronize();
+  stream->enqueue(kernel, grid, {nthreads, 1, 1});
+  stream->synchronize();
+  return torchc;
 }
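The masks table uploaded by conv_common plays the same role for bounds checking: for every padding offset (ph, pw), each entry packs TK validity bits, one per k-iteration, which the kernel tests with checka0 & (1 << rka). A minimal sketch of how one mask word is assembled, mirroring the deleted loop above (2D case; names are illustrative):

    def build_mask(i, TK, R, S, ph, pw, pad_h, pad_w):
        # One mask word: bit j says whether filter tap (i + j) stays inside
        # the padded image at padding offset (ph, pw).
        F = R * S
        mask = 0
        for j in range(TK):
            rs = (i + j) % F
            r, s = rs // S, rs % S
            in_h = pad_h <= r + ph < R + pad_h
            in_w = pad_w <= s + pw < S + pad_w
            mask |= (in_h and in_w) << j
        return mask

    # e.g. TK=8, 3x3 filter, output pixel clipped at the top-left corner
    print(bin(build_mask(0, 8, 3, 3, ph=0, pw=0, pad_h=1, pad_w=1)))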
 
-torch::Tensor conv_forward(
+torch::Tensor conv_fprop(
     const torch::Tensor data,
     const torch::Tensor weight) {
   // Check
@@ -176,6 +75,7 @@ torch::Tensor conv_forward(
   // Unpack data shapes
   const int32_t B = data.size(0);
   const int32_t Ci = data.size(1);
+  const int32_t D = 1;
   const int32_t H = data.size(2);
   const int32_t W = data.size(3);
   // Unpack weight shapes
@@ -184,109 +84,76 @@ torch::Tensor conv_forward(
   const int32_t R = weight.size(1);
   const int32_t S = weight.size(2);
   const int32_t NF = weight.size(3);
-  // Conv parameters
-  int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-  int32_t pad_d = 0, pad_h = 0, pad_w = 0;
-  int32_t stride_h = 1, stride_w = 1;
-  // Output shapes
-  int32_t P = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h;
-  int32_t Q = (W*upsample_w - S + 1 + 2*pad_w + stride_w - 1)/stride_w;
-  // Allocate output
+  // Configuration
+  const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
+  const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
+  // Check
   AT_CHECK(Ci == Cf, "Number of channels in data and weights must match");
-  torch::Tensor output = torch::empty({B, NF, P, Q}, torch::kFloat).cuda();
-  // Wrap CUDA handles
-  c10::DeviceIndex device = output.storage().device().index();
-  triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false);
-  triton::driver::stream* stream = &sstream;
-  triton::driver::context* ctx = stream->context();
-  triton::driver::cu_buffer d(ctx, (CUdeviceptr)data.storage().data(), false);
-  triton::driver::cu_buffer w(ctx, (CUdeviceptr)weight.storage().data(), false);
-  triton::driver::cu_buffer a(ctx, (CUdeviceptr)output.storage().data(), false);
-  // Create JIT
-  triton::jit jit(ctx);
-  std::vector<unsigned> params = {
-    16, 2, 64,
-    32, 2, 64,
-    16, 8, 2, 2,
-    8, 1, 8,
-    4
-  };
-  jit.add_module("conv", src, params);
-  triton::driver::kernel* kernel = jit.get_function("conv");
-  triton::jit::launch_information info = jit.get_launch_info("conv");
-  // launch info
-  unsigned TM = info.global_range_size[0];
-  unsigned TN = info.global_range_size[1];
-  unsigned TK = jit.get_int("TK");
-  // initialize constant memory
-  int FS = T*R*S;
-  int nlut = (TK + FS - 1) / FS * FS;
-  std::vector<int32_t> h_delta(nlut + upsample_d*upsample_h*upsample_w*nlut);
-  std::vector<int32_t> h_masks(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
-  // memory stride for images
-  int32_t stride_i_w = 1;
-  int32_t stride_i_h = W*stride_i_w;
-  int32_t stride_i_d = H*stride_i_h;
-  int32_t stride_i_c = 1*stride_i_d;
-  int32_t stride_i_n = Ci*stride_i_c;
-  // memory stride for activations
-  int32_t stride_o_q = 1;
-  int32_t stride_o_p = Q*stride_o_q;
-  int32_t stride_o_m = P*stride_o_p;
-  int32_t stride_o_k = 1*stride_o_m;
-  int32_t stride_o_n = NF*stride_o_k;
-  build_conv_lut(TK, stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
-  // equivalent matmul dimensions
-  int32_t M = B*P*Q;
-  int32_t N = NF;
-  int32_t K = Ci*R*S;
-  triton::driver::buffer* delta = jit.get_buffer("delta");
-  triton::driver::buffer* masks = jit.get_buffer("masks");
-  stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
-  stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
-  // launch info
-  unsigned nthreads = info.num_threads;
-  std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
-  // fast bounds-checking
-  unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
-  unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
-  unsigned lastk = TK - 1;
-  bool AT = false;
-  bool BT = true;
-  unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
-  unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
-  int32_t bound = std::max(1, std::max(K - last_safe_a, K - last_safe_b));
-  // set arguments
-  kernel->setArg(0, *d.cu());
-  kernel->setArg(1, *w.cu());
-  kernel->setArg(2, *a.cu());
-  kernel->setArg(3, M);
-  kernel->setArg(4, N);
-  kernel->setArg(5, K);
-  kernel->setArg(6, B);
-  kernel->setArg(7, H);
-  kernel->setArg(8, W);
-  kernel->setArg(9, NF);
-  kernel->setArg(10, P);
-  kernel->setArg(11, Q);
-  kernel->setArg(12, Ci);
-  kernel->setArg(13, R);
-  kernel->setArg(14, S);
-  kernel->setArg(15, stride_i_n);
-  kernel->setArg(16, stride_i_c);
-  kernel->setArg(17, stride_i_h);
-  kernel->setArg(18, stride_i_w);
-  kernel->setArg(19, stride_o_n);
-  kernel->setArg(20, stride_o_k);
-  kernel->setArg(21, stride_o_p);
-  kernel->setArg(22, stride_o_q);
-  kernel->setArg(23, pad_h);
-  kernel->setArg(24, pad_w);
-  kernel->setArg(25, bound);
-// // dry run
-  stream->enqueue(kernel, grid, {nthreads, 1, 1});
-  return output;
+  return conv_common(B, Ci, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::FPROP, data, weight);
+}
+
+torch::Tensor conv_bprop(
+    const torch::Tensor derror,
+    const torch::Tensor weight){
+  // Check
+  CHECK_INPUT(derror);
+  CHECK_INPUT(weight);
+  // Unpack error shapes
+  const int32_t B = derror.size(0);
+  const int32_t Ki = derror.size(1);
+  const int32_t M = 1;
+  const int32_t P = derror.size(2);
+  const int32_t Q = derror.size(3);
+  // Unpack weight shapes
+  const int32_t C = weight.size(0);
+  const int32_t T = 1;
+  const int32_t R = weight.size(1);
+  const int32_t S = weight.size(2);
+  const int32_t Kw = weight.size(3);
+  // Compute D, H, W
+  const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
+  const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
+  const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
+  const int32_t D = (M*stride_d + T - 1 - 2*pad_d + stride_d - 1) / upsample_d;
+  const int32_t H = (P*stride_h + R - 1 - 2*pad_h + stride_h - 1) / upsample_h;
+  const int32_t W = (Q*stride_w + S - 1 - 2*pad_w + stride_w - 1) / upsample_w;
+  // Check
+  AT_CHECK(Ki == Kw, "Number of channels in error and weights must match");
+  return conv_common(B, C, D, H, W, T, R, S, Kw, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::BPROP, derror, weight);
+}
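The D, H and W expressions in conv_bprop invert the forward output-shape formula that the deleted conv_forward code used. A quick round-trip check of that arithmetic, as a Python sketch with illustrative function names:

    def fprop_shape(H, R, pad_h, stride_h=1, upsample_h=1):
        # forward: P = (H*upsample - R + 1 + 2*pad + stride - 1) // stride
        return (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1) // stride_h

    def bprop_shape(P, R, pad_h, stride_h=1, upsample_h=1):
        # inverse used by conv_bprop: H = (P*stride + R - 1 - 2*pad + stride - 1) // upsample
        return (P*stride_h + R - 1 - 2*pad_h + stride_h - 1) // upsample_h

    H, R, pad = 8, 3, 1
    P = fprop_shape(H, R, pad)        # 8: 'same' convolution
    assert bprop_shape(P, R, pad) == H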
+
+torch::Tensor conv_wgrad(
+    const torch::Tensor data,
+    const torch::Tensor derror
+    ){
+  // Check
+  CHECK_INPUT(data);
+  CHECK_INPUT(derror);
+  // Unpack data shapes
+  const int32_t Ba = data.size(0);
+  const int32_t C = data.size(1);
+  const int32_t D = 1;
+  const int32_t H = data.size(2);
+  const int32_t W = data.size(3);
+  // Unpack error shapes
+  const int32_t Bb = derror.size(0);
+  const int32_t K = derror.size(1);
+  const int32_t M = 1;
+  const int32_t P = derror.size(2);
+  const int32_t Q = derror.size(3);
+  // Compute T, R, S
+  const int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
+  const int32_t stride_d = 1, stride_h = 1, stride_w = 1;
+  const int32_t pad_d = 0, pad_h = 1, pad_w = 1;
+  const int32_t T = (D - M*stride_d + 1 + 2*pad_d - stride_d + 1)*upsample_d;
+  const int32_t R = (H - P*stride_h + 1 + 2*pad_h - stride_h + 1)*upsample_h;
+  const int32_t S = (W - Q*stride_w + 1 + 2*pad_w - stride_w + 1)*upsample_w;
+  // Check
+  AT_CHECK(Ba == Bb, "Batch sizes of data and error must match");
+  return conv_common(Ba, C, D, H, W, T, R, S, K, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, triton::dnn::conv::WGRAD, data, derror);
+}
 
 static auto registry =
-  torch::jit::RegisterOperators("triton::conv_forward", &conv_forward);
+  torch::jit::RegisterOperators("triton::conv_fprop", &conv_fprop)
+    .op("triton::conv_bprop", &conv_bprop)
+    .op("triton::conv_wgrad", &conv_wgrad);
diff --git a/examples/python/pytorch/main.py b/examples/python/pytorch/main.py
index d4b11e316..c0568f8b4 100644
--- a/examples/python/pytorch/main.py
+++ b/examples/python/pytorch/main.py
@@ -1,9 +1,48 @@
 import torch
-from torch.autograd import Variable
+torch.manual_seed(0)
+
+class TritonConv(torch.autograd.Function):
+
+  @staticmethod
+  def forward(ctx, input, weight):
+    ctx.save_for_backward(input, weight)
+    output = torch.ops.triton.conv_fprop(input, weight)
+    return output
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, weight = ctx.saved_tensors
+    grad_input = grad_weight = None
+    if ctx.needs_input_grad[0]:
+      grad_input = torch.ops.triton.conv_bprop(grad_output.contiguous(), weight)
+    if ctx.needs_input_grad[1]:
+      grad_weight = torch.ops.triton.conv_wgrad(input, grad_output.contiguous())
+    return grad_input, grad_weight
+
+
 torch.ops.load_library("/home/philippe/Development/triton/build/examples/python/pytorch/libtorch_triton.so")
-d = torch.empty(64, 64, 64, 64).uniform_(0, 1).cuda()
-w = torch.empty(64, 3, 3, 64).uniform_(0, 1).cuda()
-a = torch.ops.triton.conv_forward(d, w)
-print(a)
+x = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
+w = torch.autograd.Variable(torch.randn(64, 3, 3, 64).cuda(), requires_grad=True)
+cuw = torch.autograd.Variable(w.permute(3,0,1,2).cuda(), requires_grad=True)
+y_target = torch.autograd.Variable(torch.randn(16, 64, 8, 8).cuda(), requires_grad=True)
+
+def run(x, w, conv):
+  y = conv(x, w)
+  loss = (y - y_target).norm(2)
+  loss.backward()
+  return loss, y.clone(), x.grad.clone(), w.grad.clone()
+
+ttloss, tty, ttdx, ttdw = run(x, w, TritonConv.apply)
+x.grad.zero_()
+w.grad.zero_()
+culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1))
+
+print((tty - cuy).norm(2))
+print((ttdx - cudx).norm(2))
+print((ttdw.permute(3,0,1,2) - cudw).norm(2))
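Besides the cuDNN comparison above, the script could be extended with PyTorch's numerical gradient checker. A sketch only: gradcheck is designed for float64, so for this fp32-only op the step size and tolerances must be loosened, and residual round-off failures are still possible:

    from torch.autograd import gradcheck

    xs = torch.randn(2, 16, 8, 8, device='cuda', requires_grad=True)
    ws = torch.randn(16, 3, 3, 16, device='cuda', requires_grad=True)
    # fp32 numerical differentiation is noisy: loosen eps and tolerances
    print(gradcheck(TritonConv.apply, (xs, ws), eps=1e-2, atol=1e-2, rtol=1e-2))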
diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
index c29bb925b..b2cbefe0f 100644
--- a/include/triton/dnn/conv.h
+++ b/include/triton/dnn/conv.h
@@ -20,12 +20,12 @@ public:
   conv(int B, int NC, int D, int H, int W, int T, int R, int S, int NF,
-       int upsample_d, int upsample_h, int upsample_w,
+       int stride_d, int stride_h, int stride_w,
        int pad_d, int pad_h, int pad_w,
        type ty = FPROP)
     : NB_(B), NC_(NC), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), NF_(NF),
-      upsample_d_(upsample_d), upsample_h_(upsample_h), upsample_w_(upsample_w),
-      stride_d_(1), stride_h_(1), stride_w_(1),
+      stride_d_(stride_d), stride_h_(stride_h), stride_w_(stride_w),
+      upsample_d_(1), upsample_h_(1), upsample_w_(1),
       pad_d_(pad_d), pad_h_(pad_h), pad_w_(pad_w),
       ty_(ty)
   {
@@ -93,6 +93,10 @@ public:
                            1, std::multiplies<int>());
   }
 
+  std::vector<int32_t> c_shapes() {
+    return shapes_c_;
+  }
+
   void build_deltas(std::vector<int32_t>& deltas){
     if(ty_ == WGRAD)
       throw std::runtime_error("no look-up table necessary for wgrad");
@@ -120,6 +124,7 @@ public:
       int32_t c = ctrs / Fs_;
       int32_t t, r, s;
       std::tie(t, r, s) = unpack(ctrs % Fs_);
+
       // next indices
       int32_t nextctrs = ctrs + TK_;
       int32_t nextc = nextctrs / Fs_;
@@ -223,6 +228,43 @@ public:
 
   std::string xprop() {
+
+    std::string declare_pb;
+    if(ty_ == FPROP){
+      declare_pb = R"(
+    fp32* pb[TN, TK] = b + rkb[newaxis, :]*ldb_s + rb0[:, newaxis];
+      )";
+    }
+    else{
+      declare_pb = R"(
+    fp32* pb_base[TN, TK] = b + rb0[:, newaxis]*ldb_c;
+    int32 rbk[TK] = rkb / (BH*BW);
+    int32 rbrs[TK] = rkb % (BH*BW);
+    int32 rbs[TK] = BW - 1 - rbrs % BW;
+    int32 rbr[TK] = BH - 1 - rbrs / BW;
+    int32 rb1[TK] = rbk*ldb_k + rbr*ldb_r + rbs*ldb_s;
+    fp32* pb[TN, TK] = pb_base + rb1[newaxis, :];
+      )";
+    }
+    std::string increment_pb;
+    if(ty_ == FPROP){
+      increment_pb = R"(
+    pb = pb + TK*ldb_s;
+      )";
+    }
+    else{
+      increment_pb = R"(
+    rkb = rkb + TK;
+    rbk = rkb / (BH*BW);
+    rbrs = rkb % (BH*BW);
+    rbs = BW - 1 - rbrs % BW;
+    rbr = BH - 1 - rbrs / BW;
+    rb1 = rbk*ldb_k + rbr*ldb_r + rbs*ldb_s;
+    pb = pb_base + rb1[newaxis, :];
+      )";
+    }
+
     std::string res =
     R"(
     const tunable int32 TM = {16, 32, 64};
@@ -246,7 +288,7 @@ public:
     int32 rxa[TM] = get_global_range[TM](0);
     int32 rb0[TN] = get_global_range[TN](1);
     int32 rka[TK] = 0 ... TK;
-    int32 rb1[TK] = 0 ... TK;
+    int32 rkb[TK] = 0 ... TK;
     fp32 C[TM, TN] = 0;
     int32 rabh[TM] = rxa / CW;
     int32 raw[TM] = rxa % CW - pad_w;
@@ -258,8 +300,8 @@ public:
     int32 rac[TK] = racr / BH;
     int32 rar[TK] = racr % BH;
     int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
-    fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
-    fp32* pb[TN, TK] = b + rb1[newaxis, :]*ldb_s + rb0[:, newaxis];
+    fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)" +
+    declare_pb + R"(
     __constant__ int32* pincd[TK] = delta + rka;
     __constant__ int32* pd[TK] = delta + BH*BW + rka;
     int32 d[TK] = *pd;
@@ -276,8 +318,8 @@ public:
     fp32 b[TN, TK] = *pb;
     for(int32 k = K; k > 0; k = k - TK){
       C = dot(a, trans(b), C);
-      pb = pb + TK*ldb_s;
-      pa = pa + d[newaxis, :];
+      pa = pa + d[newaxis, :];)" +
+    increment_pb + R"(
       b = *pb;
      pd = pd + incd;
       pincd = pincd + incd;
@@ -288,6 +330,7 @@ public:
       incm = *pincm;
       checka0 = *pm;
       checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
+      checka = checka && (k > TK);
       a = checka ? *pa : 0;
     }
     int32 rxc[TM] = get_global_range[TM](0);
@@ -379,7 +422,7 @@ public:
   {
     IN_DTYPE acc;
     for(int32_t n = 0; n < shapes_c_[0]; ++n)
-    for(int32_t k = 0; k < shapes_c_[1] ; ++k)
+    for(int32_t cf = 0; cf < shapes_c_[1] ; ++cf)
     for(int32_t cd = 0 ; cd < shapes_c_[2]; ++cd)
     for(int32_t ch = 0 ; ch < shapes_c_[3]; ++ch)
     for(int32_t cw = 0; cw < shapes_c_[4]; ++cw)
     {
@@ -388,7 +431,7 @@ public:
       int32_t d = cd*stride_d_ - pad_d_;
       int32_t h = ch*stride_h_ - pad_h_;
       int32_t w = cw*stride_w_ - pad_w_;
-      for(int32_t c = 0; c < shapes_b_[0]; ++c)
+      for(int32_t ac = 0; ac < shapes_a_[1]; ++ac)
      for(int32_t bd = 0; bd < shapes_b_[1]; ++bd)
       for(int32_t bh = 0; bh < shapes_b_[2]; ++bh)
       for(int32_t bw = 0; bw < shapes_b_[3]; ++bw){
@@ -400,11 +443,19 @@ public:
                          aw >= 0 && aw < shapes_a_[4]);
         IN_DTYPE a = 0;
         if(in_bounds)
-          a = A[n*ld_a_[0] + c*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
-        IN_DTYPE b = B[c*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + k*ld_b_[4]];
+          a = A[n*ld_a_[0] + ac*ld_a_[1] + ad*ld_a_[2] + ah*ld_a_[3] + aw*ld_a_[4]];
+        IN_DTYPE b;
+        if(ty_==FPROP)
+          b = B[ac*ld_b_[0] + bd*ld_b_[1] + bh*ld_b_[2] + bw*ld_b_[3] + cf*ld_b_[4]];
+        else{
+          int32_t bdd = bd;
+          int32_t bhh = bh;
+          int32_t bww = bw;
+          b = B[cf*ld_b_[0] + bdd*ld_b_[1] + bhh*ld_b_[2] + bww*ld_b_[3] + ac*ld_b_[4]];
+        }
         acc = std::fma(a, b, acc);
       }
-      C[n*ld_c_[0] + k*ld_c_[1] + cd*ld_c_[2] + ch*ld_c_[3] + cw*ld_c_[4]] = acc;
+      C[n*ld_c_[0] + cf*ld_c_[1] + cd*ld_c_[2] + ch*ld_c_[3] + cw*ld_c_[4]] = acc;
     }
   }
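In the BPROP branch of xprop() above, rbs = BW - 1 - rbrs % BW and rbr = BH - 1 - rbrs / BW walk the filter backwards: the data gradient is computed as a convolution with the filter rotated by 180 degrees and its channel axes swapped. That identity can be checked directly in PyTorch (a sketch; 3x3 filter, padding 1):

    import torch
    import torch.nn.functional as F

    dy = torch.randn(1, 4, 8, 8)
    w  = torch.randn(4, 4, 3, 3)              # (K, C, R, S)
    # conv_transpose2d(dy, w) == conv2d(dy, w rotated 180deg, K/C swapped)
    ref  = F.conv_transpose2d(dy, w, padding=1)
    rot  = w.flip([2, 3]).permute(1, 0, 2, 3)  # r -> R-1-r, s -> S-1-s
    test = F.conv2d(dy, rot, padding=1)
    print((ref - test).abs().max())            # ~0 up to fp32 round-off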
diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h
index a2a389cbd..86b1f2dc1 100755
--- a/include/triton/driver/dispatch.h
+++ b/include/triton/driver/dispatch.h
@@ -193,12 +193,20 @@ public:
   static cudnnStatus_t cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc, int pad_h, int pad_w, int u, int v, int upscalex, int upscaley, cudnnConvolutionMode_t mode);
   static cudnnStatus_t cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc, int arrayLength, const int padA[], const int filterStrideA[], const int upscaleA[], cudnnConvolutionMode_t mode, cudnnDataType_t dataType);
   static cudnnStatus_t cudnnSetPoolingNdDescriptor(cudnnPoolingDescriptor_t poolingDesc, const cudnnPoolingMode_t mode, const cudnnNanPropagation_t maxpoolingNanOpt, int nbDims, const int windowDimA[], const int paddingA[], const int strideA[]);
+  static cudnnStatus_t cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
+  static cudnnStatus_t cudnnTransformTensor(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
+  // pooling
+  static cudnnStatus_t cudnnPoolingForward(cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
+  // forward
   static cudnnStatus_t cudnnGetConvolutionForwardAlgorithm(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const cudnnFilterDescriptor_t wDesc, const cudnnConvolutionDescriptor_t convDesc, const cudnnTensorDescriptor_t yDesc, cudnnConvolutionFwdPreference_t preference, size_t memoryLimitInBytes, cudnnConvolutionFwdAlgo_t *algo);
   static cudnnStatus_t cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const cudnnFilterDescriptor_t wDesc, const cudnnConvolutionDescriptor_t convDesc, const cudnnTensorDescriptor_t yDesc, cudnnConvolutionFwdAlgo_t algo, size_t *sizeInBytes);
   static cudnnStatus_t cudnnConvolutionForward(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const cudnnFilterDescriptor_t wDesc, const void *w, const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, void *workSpace, size_t workSpaceSizeInBytes, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
-  static cudnnStatus_t cudnnPoolingForward(cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
-  static cudnnStatus_t cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId);
-  static cudnnStatus_t cudnnTransformTensor(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, const cudnnTensorDescriptor_t yDesc, void *y);
+  // backward data
+  static cudnnStatus_t cudnnConvolutionBackwardData(cudnnHandle_t handle, const void *alpha, const cudnnFilterDescriptor_t wDesc, const void *w, const cudnnTensorDescriptor_t dyDesc, const void *dy, const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionBwdDataAlgo_t algo, void* workSpace, size_t workSpaceSizeInBytes, const void* beta, const cudnnTensorDescriptor_t dxDesc, void *dx);
+  static cudnnStatus_t cudnnGetConvolutionBackwardDataAlgorithm(cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, const cudnnTensorDescriptor_t dyDesc, const cudnnConvolutionDescriptor_t convDesc, const cudnnTensorDescriptor_t dxDesc, cudnnConvolutionBwdDataPreference_t preference, size_t memoryLimitInBytes, cudnnConvolutionBwdDataAlgo_t* algo);
+  // backward filter
+  static cudnnStatus_t cudnnConvolutionBackwardFilter(cudnnHandle_t handle, const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, const cudnnTensorDescriptor_t dyDesc, const void *dy, const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionBwdFilterAlgo_t algo, void* workSpace, size_t workSpaceSizeInBytes, const void* beta, const cudnnFilterDescriptor_t dwDesc, void *dw);
+  static cudnnStatus_t cudnnGetConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const cudnnTensorDescriptor_t dyDesc, const cudnnConvolutionDescriptor_t convDesc, const cudnnFilterDescriptor_t dwDesc, cudnnConvolutionBwdFilterPreference_t preference, size_t memoryLimitInBytes, cudnnConvolutionBwdFilterAlgo_t* algo);
   // SPIR-V libraries
   static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);