diff --git a/examples/python/pytorch/run.py b/examples/python/pytorch/run.py index db7a1b152..906468fe6 100644 --- a/examples/python/pytorch/run.py +++ b/examples/python/pytorch/run.py @@ -33,7 +33,7 @@ class Shift(nn.Module): self.channels = in_channels self.kernel_size = kernel_size if kernel_size == 3: - p = torch.Tensor([0., 1., 0.]) + p = torch.Tensor([0.3, 0.4, 0.3]) elif kernel_size == 5: p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1]) elif kernel_size == 7: @@ -68,25 +68,24 @@ def ShiftConv2d(in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilati class NetReference(nn.Module): def __init__(self): super(NetReference, self).__init__() - #self.conv1 = ShiftConv2d(1, 32, 3, 2) - self.conv1 = triton.ShiftConv2d(1, 32, 3, 2) + self.conv1 = ShiftConv2d(1, 32, 3, 2) + #self.conv1 = triton.ShiftConv2d(1, 32, 3, 2) self.bn1 = nn.BatchNorm2d(32) - #self.conv2a = ShiftConv2d(32, 32, 3, 1) - self.conv2b = triton.ShiftConv2d(32, 32, 3, 2) - #self.conv2b = ShiftConv2d(32, 32, 3, 2) + #self.conv2 = triton.ShiftConv2d(32, 32, 3, 2) + self.conv2 = ShiftConv2d(32, 32, 3, 2) self.bn2 = nn.BatchNorm2d(32) self.fc1 = nn.Linear(32*7*7, 500) self.fc2 = nn.Linear(500, 10) def forward(self, x): - x = x.permute(1, 2, 3, 0).contiguous() + #x = x.permute(1, 2, 3, 0).contiguous() x = self.conv1(x) - x = x.permute(3, 0, 1, 2).contiguous() + #x = x.permute(3, 0, 1, 2).contiguous() x = self.bn1(x) x = F.relu(x) - x = x.permute(1, 2, 3, 0).contiguous() - x = self.conv2b(x) - x = x.permute(3, 0, 1, 2).contiguous() + #x = x.permute(1, 2, 3, 0).contiguous() + x = self.conv2(x) + #x = x.permute(3, 0, 1, 2).contiguous() x = self.bn2(x) x = F.relu(x) x = x.view(-1, 32*7*7) diff --git a/examples/python/pytorch/triton.py b/examples/python/pytorch/triton.py index 7f45daef0..efeade389 100644 --- a/examples/python/pytorch/triton.py +++ b/examples/python/pytorch/triton.py @@ -152,7 +152,7 @@ class _ShiftConvNd(torch.nn.Module): def make_shift(self, kernel_size): if kernel_size == 3: - p = torch.Tensor([0., 1., 0.]) + p = torch.Tensor([0.3, 0.4, 0.3]) elif kernel_size == 5: p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1]) elif kernel_size == 7: diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py index c44e0edab..ee1322d5c 100644 --- a/examples/python/tensorflow/run.py +++ b/examples/python/tensorflow/run.py @@ -58,24 +58,29 @@ def blocksparse_matmul_grad(op, dy): return (dx, dw) def run_shift(): - B, C, H, W = 16, 1, 4, 4 - R, S, F = 3, 3, 32 + B, C, H, W = 16, 16, 4, 4 + R, S, F = 3, 3, 16 stride_h, stride_w = 2, 2 np.random.seed(2) a = tf.placeholder(tf.float32, shape=[C, H, W, B]) b = tf.placeholder(tf.float32, shape=[C, F]) - hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32) - hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32) + #hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32) + #hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32) + hshift_h = np.zeros(C, dtype=np.int32) + hshift_w = np.zeros(C, dtype=np.int32) c = module.shift_conv(a, b, stride_h=stride_h, stride_w=stride_w, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w)) # feed values - ha = np.ones((C, H, W, B), dtype=np.float32) - hb = np.ones((C, F), dtype=np.float32) + ha = np.random.rand(C, H, W, B) + hb = np.random.rand(C, F) + #ha = np.ones((C, H, W, B), dtype=np.float32) + #hb = np.ones((C, F), dtype=np.float32) sess = tf.InteractiveSession() # test grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H//stride_h, W//stride_w, B), extra_feed_dict = {a: ha, b: hb}) dw_t, dw_n = grads[1] dx_t, dx_n = grads[0] + print(dw_t, dw_n) print(np.max(np.abs(dw_t - dw_n))) print(np.max(np.abs(dx_t - dx_n))) # Run diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index cbae500ed..e537cc563 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -139,6 +139,13 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, const std::vector &ranges, size_t nthreads) { int32_t lda = AT_ ? K_ : M_; int32_t ldb = BT_ ? N_ : K_; + int32_t ldc = M_; + if(ty_ == FPROP) + lda *= stride_h_*stride_w_; + if(ty_ == WGRAD) + ldb *= stride_h_*stride_w_; + if(ty_ == BPROP) + ldc *= stride_h_*stride_w_; driver::buffer *a = args[0], *b = args[1], *c = args[2]; kernel->setArg(0, a); kernel->setArg(1, b); @@ -150,15 +157,18 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, kernel->setArg(7, stride_w_); kernel->setArg(8, lda); kernel->setArg(9, ldb); - kernel->setArg(10, B_); - kernel->setArg(11, AH_); - kernel->setArg(12, AW_); - kernel->setArg(13, BH_); - kernel->setArg(14, BW_); + kernel->setArg(10, ldc); + kernel->setArg(11, B_); + kernel->setArg(12, AH_); + kernel->setArg(13, AW_); + kernel->setArg(14, BH_); + kernel->setArg(15, BW_); + kernel->setArg(16, CH_); + kernel->setArg(17, CW_); unsigned TM = ranges[0], TN = ranges[1]; std::array grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; if(ty_ == BPROP) - ((driver::cu_buffer*)c)->set_zero(stream, M_*N_*stride_h_*stride_w_*4); + ((driver::cu_buffer*)c)->set_zero(stream, ldc*N_*4); stream->enqueue(kernel, grid, {nthreads, 1, 1}); } @@ -205,22 +215,21 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a, fp32 *c, int32 M, int32 N, int32 K, int32 stride_h, int32 stride_w, - int32 lda, int32 ldb, - int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) { + int32 lda, int32 ldb, int32 ldc, + int32 NB, int32 AH, int32 AW, int32 BH, int32 BW, int32 CH, int32 CW) { int32 rxa[TM] = get_global_range[TM](0); int32 ryb[TN] = get_global_range[TN](1); int32 rka[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK; fp32 C[TM, TN] = 0; - int32 pad_h = AR / 2; - int32 pad_w = AS / 2;)"; + int32 pad_h = BH / 2; + int32 pad_w = BW / 2;)"; if(ty_ == FPROP){ os << R"( - int32 rawhc[TM] = rxa / ABS; - int32 rab[TM] = rxa % ABS; - int32 raw[TM] = (rawhc % AW)*stride_w; - int32 rahc[TM] = rawhc / AW; - int32 rah[TM] = (rahc % AH)*stride_h; + int32 rawh[TM] = rxa / NB; + int32 rab[TM] = rxa % NB; + int32 raw[TM] = (rawh % CW)*stride_w; + int32 rah[TM] = (rawh / CW)*stride_h; __constant__ int32* pd[TK] = delta + rka; multiple_of(4) int32 d[TK] = *pd; int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h)); @@ -229,43 +238,41 @@ if(ty_ == FPROP){ int32 inc_true[TM, TK] = d[newaxis, :]; int32 inc_false[TM, TK] = rka[newaxis, :] * lda; int32 inc[TM, TK] = interior ? inc_true : inc_false; - rxa = rab + raw*ABS + rah*ABS*AW; - int32 offa0[TM, TK] = rxa[:, newaxis];)"; + int32 offxa[TM] = rab + raw*NB + rah*NB*AW;)"; } else{ - os << " int32 offa0[" << AS << "] = rxa" << bca1 << lda1 << ";" << std::endl; + os << R"( + int32 offxa[TM] = rxa;)"; } if(ty_ == WGRAD){ os << R"( __constant__ int32* pd[TN] = delta + ryb; int32 d[TN] = *pd; int32 shift[TK, TN] = d[newaxis, :]; - int32 rbwhc[TK] = rkb / ABS; - int32 rbw[TK] = (rbwhc % AW)*stride_w; - int32 rbhc[TK] = rbwhc / AW; - int32 rbh[TK] = (rbhc % AH)*stride_h; - )"; -} - os << R"( - )" << a_ty_ << "* pa[" << AS << "] = a + offa0 + " << rka << bca0 << lda0 << R"(; - )" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << ldb1 << " + " << rkb << bcb0 << ldb0 << R"(; - int1 checka[)" << AS << "] = (rka < K)" << bca0 << R"(; - int1 checkb[)" << BS << "] = (rkb < K)" << bcb0 << R"(; - )" << a_ty_ << " a[" << AS << R"(] = checka ? *pa : 0;)"; -if(ty_ == WGRAD){ - os << R"( - int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h)); - int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w)); - int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis]; - int32 inc[TK, TN] = interior ? shift : 0; - )" << b_ty_ << R"(* shifted_pb[TK, TN] = pb + inc; - )" << b_ty_ << R"( b[TK, TN] = checkb ? *shifted_pb : 0;)"; + int32 rbwh[TK] = rkb / NB; + int32 rbb[TK] = rkb % NB; + int32 rbw[TK] = (rbwh % CW)*stride_w; + int32 rbh[TK] = (rbwh / CW)*stride_h; + int32 offkb[TK] = rbb + rbw*NB + rbh*NB*AW; + int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h)); + int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w)); + int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis]; + int32 inc[TK, TN] = interior ? shift : 0; + )" << b_ty_ << "* pb_base[" << BS << "] = b + ryb" << bcb1 << ldb1 << R"(; + )" << b_ty_ << "* pb[" << BS << "] = pb_base + offkb[:, newaxis] + inc;"; } else{ os << R"( - )" << b_ty_ << " b[" << BS << R"(] = checkb ? *pb : 0;)"; + int32 offkb[TK] = rkb; + )" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << ldb1 << " + " << "offkb" << bcb0 << ldb0 << R"(; + )"; } os << R"( + )" << a_ty_ << "* pa[" << AS << "] = a + offxa" << bca1 << lda1 << " + " << rka << bca0 << lda0 << R"(; + int1 checka[)" << AS << "] = (rka < K)" << bca0 << R"(; + int1 checkb[)" << BS << "] = (rkb < K)" << bcb0 << R"(; + )" << a_ty_ << " a[" << AS << R"(] = checka ? *pa : 0; + )" << b_ty_ << " b[" << BS << R"(] = checkb ? *pb : 0; for(int32 k = K; k > 0; k = k - TK){ C = dot()" << usea << "," << useb << R"(, C); int1 checka[)" << AS << R"(] = k > TK; @@ -287,18 +294,18 @@ else{ } if(ty_ == WGRAD){ os << R"( - pb = pb + TK)" << ldb0 << R"(; rkb = rkb + TK; - rbwhc = rkb / ABS; - rbw = (rbwhc % AW)*stride_w; - rbhc = rbwhc / AW; - rbh = (rbhc % AH)*stride_h; + rbwh = rkb / NB; + rbb = rkb % NB; + rbw = (rbwh % CW)*stride_w; + rbh = (rbwh / CW)*stride_h; + offkb = rbb + rbw*NB + rbh*NB*AW; interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h)); interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w)); interior = interiorh[:, newaxis] && interiorw[:, newaxis]; inc = interior ? shift : 0; - shifted_pb = pb + inc; - @checkb b = *shifted_pb;)"; + pb = pb_base + offkb[:, newaxis] + inc; + @checkb b = *pb;)"; } else{ os << R"( @@ -311,20 +318,20 @@ else{ int32 ryc[TN] = get_global_range[TN](1);)"; if(ty_ == BPROP){ os << R"( - int32 rcwhc[TM] = rxc / ABS; - int32 rcb[TM] = rxc % ABS; - int32 rcw[TM] = (rcwhc % AW)*stride_w; - int32 rchc[TM] = rcwhc / AW; - int32 rch[TM] = (rchc % AH)*stride_h; - rxc = rcb + rcw*ABS + rch*ABS*AW; - int32 offc0[TM, TN] = rxc[:, newaxis];)"; + int32 rcwh[TM] = rxc / NB; + int32 rcb[TM] = rxc % NB; + int32 rcw[TM] = (rcwh % CW) * stride_w; + int32 rch[TM] = (rcwh / CW) * stride_h; + int32 offxc[TM] = rcb + rcw*NB + rch*NB*AW; + )"; } else{ os << R"( - int32 offc0[TM, TN] = rxc[:, newaxis];)"; + int32 offxc[TM] = rxc; + )"; } os << R"(" - fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + offc0; + fp32* pc[TM, TN] = c + ryc[newaxis, :]*ldc + offxc[:, newaxis]; int1 checkc0[TM] = rxc < M; int1 checkc1[TN] = ryc < N; int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";