diff --git a/examples/python/pytorch/run.py b/examples/python/pytorch/run.py index 59f70d6c5..e7c10112c 100644 --- a/examples/python/pytorch/run.py +++ b/examples/python/pytorch/run.py @@ -65,55 +65,33 @@ def ShiftConv2d(in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilati ) -class NetReference(nn.Module): +class Net(nn.Module): def __init__(self): - super(NetReference, self).__init__() - #self.conv1 = ShiftConv2d(1, 32, 3, 2) - self.conv1 = triton.ShiftConv2d(1, 32, 3, 2) - self.bn1 = nn.BatchNorm2d(32) - self.conv2 = triton.ShiftConv2d(32, 32, 3, 2) - #self.conv2 = ShiftConv2d(32, 32, 3, 2) - self.bn2 = nn.BatchNorm2d(32) - self.fc1 = nn.Linear(32*7*7, 500) + super(Net, self).__init__() + self.conv1 = ShiftConv2d(1, 32, 3, 1) + self.conv2 = ShiftConv2d(32, 128, 3, 1) + self.conv3 = ShiftConv2d(128, 128, 3, 2) + self.bn1 = nn.BatchNorm2d(128) + self.conv4 = ShiftConv2d(128, 256, 3, 2) + self.bn2 = nn.BatchNorm2d(256) + self.fc1 = nn.Linear(256*7*7, 500) self.fc2 = nn.Linear(500, 10) def forward(self, x): x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) x = self.bn1(x) x = F.relu(x) - x = self.conv2(x) + x = self.conv4(x) x = self.bn2(x) x = F.relu(x) - x = x.view(-1, 32*7*7) + x = x.view(-1, 256*7*7) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) -class NetTriton(nn.Module): - def __init__(self): - super(NetTriton, self).__init__() - self.conv1 = triton.ShiftConv2d(1, 32, 3, 2) - self.bn1 = triton.BatchNorm2d(32) - self.conv2 = triton.ShiftConv2d(32, 64, 3, 2) - self.bn2 = triton.BatchNorm2d(64) - self.fc1 = nn.Linear(64*7*7, 500) - self.fc2 = nn.Linear(500, 10) - - def forward(self, x): - x = x.permute(1, 2, 3, 0).contiguous() - x = self.conv1(x) - x = self.bn1(x) - x = F.relu(x) - x = self.conv2(x) - x = self.bn2(x) - x = F.relu(x) - x = x.permute(3, 0, 1, 2).contiguous() - x = x.view(-1, 64*7*7) - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return F.log_softmax(x, dim=1) - -Net = NetReference() +Net = Net() def train(args, model, device, train_loader, optimizer, epoch): model.train() diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py index 9de71d8a4..57850de9a 100644 --- a/examples/python/tensorflow/run.py +++ b/examples/python/tensorflow/run.py @@ -58,7 +58,7 @@ def blocksparse_matmul_grad(op, dy): return (dx, dw) def run_shift(): - B, C, H, W = 16, 16, 4, 4 + B, C, H, W = 16, 16, 2, 2 R, S, F = 3, 3, 32 stride_h, stride_w = 2, 2 np.random.seed(2) diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h index 1731508d0..57cb5ea0a 100644 --- a/include/triton/dnn/shift.h +++ b/include/triton/dnn/shift.h @@ -62,7 +62,7 @@ public: shift(int B, int NC, int D, int H, int W, int T, int R, int S, int NF, - int stride_h, int stride_w, + int stride_h, int stride_w, const int32_t* shift_h, const int32_t* shift_w, std::string a_ty = "fp32", std::string b_ty = "fp32", type ty = FPROP, bool bias = false, layout_t layout = CHWN); @@ -145,6 +145,8 @@ private: // shift values const int32_t* shift_h_; const int32_t* shift_w_; + bool shift_edge_h_; + bool shift_edge_w_; // look-up tables std::vector h_delta_a; std::vector h_delta_b; @@ -154,7 +156,7 @@ private: std::string a_ty_; std::string b_ty_; // convolution type - type ty_; + type op_; bool bias_; // transpose bool AT_; diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index da3b5877d..aeaba72a4 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -23,8 +23,9 @@ shift::shift(int B, int C, stride_d_(1), stride_h_(stride_h), stride_w_(stride_w), shift_h_(shift_h), shift_w_(shift_w), a_ty_(a_ty), b_ty_(b_ty), - ty_(ty), bias_(bias), + op_(ty), bias_(bias), layout_(layout){ +// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl; // max number of channels TK_ = 16; MAX_C_ = 8192 + TK_; @@ -51,6 +52,9 @@ shift::shift(int B, int C, default: throw std::runtime_error("unsupported input layout"); } + // Shift edge + shift_edge_h_ = (AH_ == stride_h_); + shift_edge_w_ = (AW_ == stride_w_); // B memory strides: [C, F] ldb_n_ = 1; ldb_h_ = 1; @@ -88,7 +92,7 @@ shift::shift(int B, int C, if(layout_ == NCHW) shapes_c_ = {B, F, CH_, CW_}; // Weight gradient - if(ty_ == WGRAD){ + if(op_ == WGRAD){ // b <-> c // b <-> a std::swap(ldb_n_, ldc_n_); @@ -106,7 +110,7 @@ shift::shift(int B, int C, shapes_c_ = {C, F}; } // Input gradient - if(ty_ == BPROP){ + if(op_ == BPROP){ // a <-> c std::swap(lda_n_, ldc_n_); std::swap(lda_w_, ldc_w_); @@ -128,10 +132,12 @@ base* shift::clone() const { void shift::build_delta_a() { h_delta_a.resize(MAX_C_); - if(ty_ == FPROP){ + auto shift_h = [&](int c) { return shift_edge_h_ ? std::max(0, shift_h_[c]) : shift_h_[c]; }; + auto shift_w = [&](int c) { return shift_edge_w_ ? std::max(0, shift_w_[c]) : shift_w_[c]; }; + if(op_ == FPROP){ // compute offset auto offset = [&](unsigned c) { - return c*lda_c_ + shift_h_[c]*lda_h_ + shift_w_[c]*lda_w_; + return c*lda_c_ + shift_h(c)*lda_h_ + shift_w(c)*lda_w_; }; // populate look-up table for(unsigned c = 0; c < TK_; c++) @@ -139,14 +145,14 @@ void shift::build_delta_a() { for(unsigned c = 0; c < C_; c++) h_delta_a[TK_ + c] = offset(c + TK_) - offset(c); } - if(ty_ == BPROP){ + if(op_ == BPROP){ for(unsigned c = 0; c < C_; c++){ - h_delta_a[c] = shift_h_[c]*ldc_h_ + shift_w_[c]*ldc_w_; + h_delta_a[c] = shift_h(c)*ldc_h_ + shift_w(c)*ldc_w_; } } - if(ty_ == WGRAD){ + if(op_ == WGRAD){ for(unsigned c = 0; c < C_; c++) - h_delta_a[c] = shift_h_[c]*ldb_h_ + shift_w_[c]*ldb_w_; + h_delta_a[c] = shift_h(c)*ldb_h_ + shift_w(c)*ldb_w_; } } @@ -167,10 +173,22 @@ bool shift::operator <(const base& other) const{ auto *y = dynamic_cast(&other); if(!y) return true; - return std::tie(B_, C_, AD_, AH_, AW_, BD_, BH_, BW_, F_, - shift_h_, shift_w_, ty_, bias_) - < std::tie(y->B_, y->C_, y->AD_, y->AH_, y->AW_, y->BD_, y->BH_, y->BW_, y->F_, - y->shift_h_, y->shift_w_, y->ty_, y->bias_); + return std::tie(B_, C_, F_, + AD_, AH_, AW_, + BD_, BH_, BW_, + CD_, CH_, CW_, + shift_h_, shift_w_, + stride_h_, stride_w_, + layout_, op_, + bias_) + < std::tie(y->B_, y->C_, y->F_, + y->AD_, y->AH_, y->AW_, + y->BD_, y->BH_, y->BW_, + y->CD_, y->CH_, y->CW_, + y->shift_h_, y->shift_w_, + y->stride_h_, y->stride_w_, + y->layout_, y->op_, + y->bias_); } void shift::init_impl(driver::stream *stream, driver::cu_module *module) { @@ -212,7 +230,7 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, kernel->setArg(26, CW_); unsigned TM = ranges[0], TN = ranges[1]; std::array grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; - if(ty_ == BPROP) + if(op_ == BPROP) ((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*4); stream->enqueue(kernel, grid, {nthreads, 1, 1}); } @@ -263,7 +281,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *A, int32 pad_w = BW / 2;)"; /* A offsets */ -if(ty_ == FPROP){ +if(op_ == FPROP){ os << R"( int32 rawh[TM] = rxa / NB; int32 rab[TM] = rxa % NB; @@ -274,13 +292,20 @@ if(ty_ == FPROP){ __constant__ int32* pd[TK] = delta_a + rka; multiple_of(4) int32 d[TK] = *pd; int32 offa_interior[TM, TK] = d[newaxis, :]; - int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c; - int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h)); - int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w)); + int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;\n)"; + if(shift_edge_h_) + os << " int1 interiorh[TM] = 1;"; + else + os << " int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));"; + if(shift_edge_w_) + os << " int1 interiorw[TM] = 1;"; + else + os << " int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));"; + os << R"( int1 interior[TM, TK] = interiorh[:, newaxis] && interiorw[:, newaxis]; int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)"; } -if(ty_ == BPROP){ +if(op_ == BPROP){ os << R"( int32 rawh[TM] = rxa / NB; int32 rab[TM] = rxa % NB; @@ -290,12 +315,12 @@ if(ty_ == BPROP){ int32 offa0[TM, TK] = offxa[:, newaxis]; int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; } -if(ty_ == WGRAD && layout_ == CHWN){ +if(op_ == WGRAD && layout_ == CHWN){ os << R"( int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; int32 offa1[TK, TM] = rka[:, newaxis];)"; } -if(ty_ == WGRAD && layout_ == NCHW){ +if(op_ == WGRAD && layout_ == NCHW){ os << R"( int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; int32 rawh[TK] = rka / NB; @@ -307,17 +332,17 @@ if(ty_ == WGRAD && layout_ == NCHW){ } /* B offsets */ -if(ty_ == FPROP){ +if(op_ == FPROP){ os << R"( int32 offb0[TN, TK] = ryb[:, newaxis]; int32 offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)"; } -if(ty_ == BPROP){ +if(op_ == BPROP){ os << R"( int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int32 offb1[TK, TN] = rkb[:, newaxis];)"; } -if(ty_ == WGRAD){ +if(op_ == WGRAD){ os << R"( __constant__ int32* pd[TN] = delta_a + ryb; int32 d[TN] = *pd; @@ -326,9 +351,16 @@ if(ty_ == WGRAD){ int32 rbb[TK] = rkb % NB; int32 rbw[TK] = (rbwh % CW)*stride_w; int32 rbh[TK] = (rbwh / CW)*stride_h; - int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; - int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h)); - int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w)); + int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;\n)"; + if(shift_edge_h_) + os << " int1 interiorh[TK] = 1;\n"; + else + os << " int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));\n"; + if(shift_edge_w_) + os << " int1 interiorw[TK] = 1;\n"; + else + os << " int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));\n"; + os << R"( int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis]; int32 incb[TK, TN] = interior ? shift : 0; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; @@ -349,7 +381,7 @@ if(ty_ == WGRAD){ int1 checkb[)" << BS << R"(] = k > TK;)"; /* Increment A pointers */ -if(ty_ == FPROP){ +if(op_ == FPROP){ os << R"( pd = pd + TK; d = *pd; @@ -358,15 +390,15 @@ if(ty_ == FPROP){ int32 offa[TM, TK] = interior ? offa_interior : offa_exterior; pa = pa + offa;)"; } -if(ty_ == BPROP){ +if(op_ == BPROP){ os << R"( pa = pa + TK * lda_c;)"; } -if(ty_ == WGRAD && layout_ == CHWN){ +if(op_ == WGRAD && layout_ == CHWN){ os << R"( pa = pa + TK;)"; } -if(ty_ == WGRAD && layout_ == NCHW){ +if(op_ == WGRAD && layout_ == NCHW){ os << R"( rka = rka + TK; rawh = rka / NB; @@ -380,25 +412,32 @@ if(ty_ == WGRAD && layout_ == NCHW){ @checka a = *pa;)"; /* Increment B pointers */ -if(ty_ == WGRAD){ +if(op_ == WGRAD){ os << R"( rkb = rkb + TK; rbwh = rkb / NB; rbb = rkb % NB; rbw = (rbwh % CW)*stride_w; rbh = (rbwh / CW)*stride_h; - offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; - interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h)); - interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w)); + offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;\n)"; + if(shift_edge_h_) + os << " interiorh = 1;\n"; + else + os << " interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));\n"; + if(shift_edge_w_) + os << " interiorw = 1;\n"; + else + os << " interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));\n"; + os << R"( interior = interiorh[:, newaxis] && interiorw[:, newaxis]; incb = interior ? shift : 0; pb = B + offb0 + offkb[:, newaxis] + incb;)"; } -if(ty_ == FPROP){ +if(op_ == FPROP){ os << R"( pb = pb + TK * ldb_c;)"; } -if(ty_ == BPROP){ +if(op_ == BPROP){ os << R"( pb = pb + TK;)"; } @@ -409,7 +448,7 @@ if(ty_ == BPROP){ int32 ryc[TN] = get_global_range[TN](1);)"; /* C offsets */ -if(ty_ == BPROP){ +if(op_ == BPROP){ os << R"( int32 rcwh[TM] = rxc / NB; int32 rcb[TM] = rxc % NB; @@ -418,7 +457,7 @@ if(ty_ == BPROP){ int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h; )"; } -if(ty_ == FPROP){ +if(op_ == FPROP){ os << R"( int32 rcwh[TM] = rxc / NB; int32 rcb[TM] = rxc % NB; @@ -427,7 +466,7 @@ if(ty_ == FPROP){ int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h; )"; } -if(ty_ == WGRAD){ +if(op_ == WGRAD){ os << R"( int32 offxc[TM] = rxc; )"; @@ -437,10 +476,17 @@ if(ty_ == WGRAD){ int1 checkc0[TM] = rxc < M; int1 checkc1[TN] = ryc < N; int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)"; -if(ty_ == BPROP){ +if(op_ == BPROP){ + os << "\n"; + if(shift_edge_h_) + os << " int1 interiorh[TM] = 1;\n"; + else + os << " int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));\n"; + if(shift_edge_w_) + os << " int1 interiorw[TM] = 1;\n"; + else + os << " int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));\n"; os << R"( - int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h)); - int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w)); int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis]; __constant__ int32* pd[TN] = delta_a + ryc; fp32* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];