[dnn/shift] many bugfixes in strided shift-conv
This commit is contained in:
@@ -33,7 +33,7 @@ class Shift(nn.Module):
|
||||
self.channels = in_channels
|
||||
self.kernel_size = kernel_size
|
||||
if kernel_size == 3:
|
||||
p = torch.Tensor([0., 1., 0.])
|
||||
p = torch.Tensor([0.3, 0.4, 0.3])
|
||||
elif kernel_size == 5:
|
||||
p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1])
|
||||
elif kernel_size == 7:
|
||||
@@ -68,25 +68,24 @@ def ShiftConv2d(in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilati
|
||||
class NetReference(nn.Module):
|
||||
def __init__(self):
|
||||
super(NetReference, self).__init__()
|
||||
#self.conv1 = ShiftConv2d(1, 32, 3, 2)
|
||||
self.conv1 = triton.ShiftConv2d(1, 32, 3, 2)
|
||||
self.conv1 = ShiftConv2d(1, 32, 3, 2)
|
||||
#self.conv1 = triton.ShiftConv2d(1, 32, 3, 2)
|
||||
self.bn1 = nn.BatchNorm2d(32)
|
||||
#self.conv2a = ShiftConv2d(32, 32, 3, 1)
|
||||
self.conv2b = triton.ShiftConv2d(32, 32, 3, 2)
|
||||
#self.conv2b = ShiftConv2d(32, 32, 3, 2)
|
||||
#self.conv2 = triton.ShiftConv2d(32, 32, 3, 2)
|
||||
self.conv2 = ShiftConv2d(32, 32, 3, 2)
|
||||
self.bn2 = nn.BatchNorm2d(32)
|
||||
self.fc1 = nn.Linear(32*7*7, 500)
|
||||
self.fc2 = nn.Linear(500, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(1, 2, 3, 0).contiguous()
|
||||
#x = x.permute(1, 2, 3, 0).contiguous()
|
||||
x = self.conv1(x)
|
||||
x = x.permute(3, 0, 1, 2).contiguous()
|
||||
#x = x.permute(3, 0, 1, 2).contiguous()
|
||||
x = self.bn1(x)
|
||||
x = F.relu(x)
|
||||
x = x.permute(1, 2, 3, 0).contiguous()
|
||||
x = self.conv2b(x)
|
||||
x = x.permute(3, 0, 1, 2).contiguous()
|
||||
#x = x.permute(1, 2, 3, 0).contiguous()
|
||||
x = self.conv2(x)
|
||||
#x = x.permute(3, 0, 1, 2).contiguous()
|
||||
x = self.bn2(x)
|
||||
x = F.relu(x)
|
||||
x = x.view(-1, 32*7*7)
|
||||
|
@@ -152,7 +152,7 @@ class _ShiftConvNd(torch.nn.Module):
|
||||
|
||||
def make_shift(self, kernel_size):
|
||||
if kernel_size == 3:
|
||||
p = torch.Tensor([0., 1., 0.])
|
||||
p = torch.Tensor([0.3, 0.4, 0.3])
|
||||
elif kernel_size == 5:
|
||||
p = torch.Tensor([0.1, 0.25, 0.3, 0.25, 0.1])
|
||||
elif kernel_size == 7:
|
||||
|
@@ -58,24 +58,29 @@ def blocksparse_matmul_grad(op, dy):
|
||||
return (dx, dw)
|
||||
|
||||
def run_shift():
|
||||
B, C, H, W = 16, 1, 4, 4
|
||||
R, S, F = 3, 3, 32
|
||||
B, C, H, W = 16, 16, 4, 4
|
||||
R, S, F = 3, 3, 16
|
||||
stride_h, stride_w = 2, 2
|
||||
np.random.seed(2)
|
||||
a = tf.placeholder(tf.float32, shape=[C, H, W, B])
|
||||
b = tf.placeholder(tf.float32, shape=[C, F])
|
||||
hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32)
|
||||
hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
|
||||
#hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32)
|
||||
#hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
|
||||
hshift_h = np.zeros(C, dtype=np.int32)
|
||||
hshift_w = np.zeros(C, dtype=np.int32)
|
||||
c = module.shift_conv(a, b, stride_h=stride_h, stride_w=stride_w, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
|
||||
# feed values
|
||||
ha = np.ones((C, H, W, B), dtype=np.float32)
|
||||
hb = np.ones((C, F), dtype=np.float32)
|
||||
ha = np.random.rand(C, H, W, B)
|
||||
hb = np.random.rand(C, F)
|
||||
#ha = np.ones((C, H, W, B), dtype=np.float32)
|
||||
#hb = np.ones((C, F), dtype=np.float32)
|
||||
sess = tf.InteractiveSession()
|
||||
# test
|
||||
grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H//stride_h, W//stride_w, B),
|
||||
extra_feed_dict = {a: ha, b: hb})
|
||||
dw_t, dw_n = grads[1]
|
||||
dx_t, dx_n = grads[0]
|
||||
print(dw_t, dw_n)
|
||||
print(np.max(np.abs(dw_t - dw_n)))
|
||||
print(np.max(np.abs(dx_t - dx_n)))
|
||||
# Run
|
||||
|
@@ -139,6 +139,13 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
const std::vector<unsigned> &ranges, size_t nthreads) {
|
||||
int32_t lda = AT_ ? K_ : M_;
|
||||
int32_t ldb = BT_ ? N_ : K_;
|
||||
int32_t ldc = M_;
|
||||
if(ty_ == FPROP)
|
||||
lda *= stride_h_*stride_w_;
|
||||
if(ty_ == WGRAD)
|
||||
ldb *= stride_h_*stride_w_;
|
||||
if(ty_ == BPROP)
|
||||
ldc *= stride_h_*stride_w_;
|
||||
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
@@ -150,15 +157,18 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(7, stride_w_);
|
||||
kernel->setArg(8, lda);
|
||||
kernel->setArg(9, ldb);
|
||||
kernel->setArg(10, B_);
|
||||
kernel->setArg(11, AH_);
|
||||
kernel->setArg(12, AW_);
|
||||
kernel->setArg(13, BH_);
|
||||
kernel->setArg(14, BW_);
|
||||
kernel->setArg(10, ldc);
|
||||
kernel->setArg(11, B_);
|
||||
kernel->setArg(12, AH_);
|
||||
kernel->setArg(13, AW_);
|
||||
kernel->setArg(14, BH_);
|
||||
kernel->setArg(15, BW_);
|
||||
kernel->setArg(16, CH_);
|
||||
kernel->setArg(17, CW_);
|
||||
unsigned TM = ranges[0], TN = ranges[1];
|
||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||
if(ty_ == BPROP)
|
||||
((driver::cu_buffer*)c)->set_zero(stream, M_*N_*stride_h_*stride_w_*4);
|
||||
((driver::cu_buffer*)c)->set_zero(stream, ldc*N_*4);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
@@ -205,22 +215,21 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 stride_h, int32 stride_w,
|
||||
int32 lda, int32 ldb,
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
|
||||
int32 lda, int32 ldb, int32 ldc,
|
||||
int32 NB, int32 AH, int32 AW, int32 BH, int32 BW, int32 CH, int32 CW) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 pad_h = AR / 2;
|
||||
int32 pad_w = AS / 2;)";
|
||||
int32 pad_h = BH / 2;
|
||||
int32 pad_w = BW / 2;)";
|
||||
if(ty_ == FPROP){
|
||||
os << R"(
|
||||
int32 rawhc[TM] = rxa / ABS;
|
||||
int32 rab[TM] = rxa % ABS;
|
||||
int32 raw[TM] = (rawhc % AW)*stride_w;
|
||||
int32 rahc[TM] = rawhc / AW;
|
||||
int32 rah[TM] = (rahc % AH)*stride_h;
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
int32 raw[TM] = (rawh % CW)*stride_w;
|
||||
int32 rah[TM] = (rawh / CW)*stride_h;
|
||||
__constant__ int32* pd[TK] = delta + rka;
|
||||
multiple_of(4) int32 d[TK] = *pd;
|
||||
int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
|
||||
@@ -229,43 +238,41 @@ if(ty_ == FPROP){
|
||||
int32 inc_true[TM, TK] = d[newaxis, :];
|
||||
int32 inc_false[TM, TK] = rka[newaxis, :] * lda;
|
||||
int32 inc[TM, TK] = interior ? inc_true : inc_false;
|
||||
rxa = rab + raw*ABS + rah*ABS*AW;
|
||||
int32 offa0[TM, TK] = rxa[:, newaxis];)";
|
||||
int32 offxa[TM] = rab + raw*NB + rah*NB*AW;)";
|
||||
}
|
||||
else{
|
||||
os << " int32 offa0[" << AS << "] = rxa" << bca1 << lda1 << ";" << std::endl;
|
||||
os << R"(
|
||||
int32 offxa[TM] = rxa;)";
|
||||
}
|
||||
if(ty_ == WGRAD){
|
||||
os << R"(
|
||||
__constant__ int32* pd[TN] = delta + ryb;
|
||||
int32 d[TN] = *pd;
|
||||
int32 shift[TK, TN] = d[newaxis, :];
|
||||
int32 rbwhc[TK] = rkb / ABS;
|
||||
int32 rbw[TK] = (rbwhc % AW)*stride_w;
|
||||
int32 rbhc[TK] = rbwhc / AW;
|
||||
int32 rbh[TK] = (rbhc % AH)*stride_h;
|
||||
)";
|
||||
}
|
||||
os << R"(
|
||||
)" << a_ty_ << "* pa[" << AS << "] = a + offa0 + " << rka << bca0 << lda0 << R"(;
|
||||
)" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << ldb1 << " + " << rkb << bcb0 << ldb0 << R"(;
|
||||
int1 checka[)" << AS << "] = (rka < K)" << bca0 << R"(;
|
||||
int1 checkb[)" << BS << "] = (rkb < K)" << bcb0 << R"(;
|
||||
)" << a_ty_ << " a[" << AS << R"(] = checka ? *pa : 0;)";
|
||||
if(ty_ == WGRAD){
|
||||
os << R"(
|
||||
int32 rbwh[TK] = rkb / NB;
|
||||
int32 rbb[TK] = rkb % NB;
|
||||
int32 rbw[TK] = (rbwh % CW)*stride_w;
|
||||
int32 rbh[TK] = (rbwh / CW)*stride_h;
|
||||
int32 offkb[TK] = rbb + rbw*NB + rbh*NB*AW;
|
||||
int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));
|
||||
int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
|
||||
int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
int32 inc[TK, TN] = interior ? shift : 0;
|
||||
)" << b_ty_ << R"(* shifted_pb[TK, TN] = pb + inc;
|
||||
)" << b_ty_ << R"( b[TK, TN] = checkb ? *shifted_pb : 0;)";
|
||||
)" << b_ty_ << "* pb_base[" << BS << "] = b + ryb" << bcb1 << ldb1 << R"(;
|
||||
)" << b_ty_ << "* pb[" << BS << "] = pb_base + offkb[:, newaxis] + inc;";
|
||||
}
|
||||
else{
|
||||
os << R"(
|
||||
)" << b_ty_ << " b[" << BS << R"(] = checkb ? *pb : 0;)";
|
||||
int32 offkb[TK] = rkb;
|
||||
)" << b_ty_ << "* pb[" << BS << "] = b + ryb" << bcb1 << ldb1 << " + " << "offkb" << bcb0 << ldb0 << R"(;
|
||||
)";
|
||||
}
|
||||
os << R"(
|
||||
)" << a_ty_ << "* pa[" << AS << "] = a + offxa" << bca1 << lda1 << " + " << rka << bca0 << lda0 << R"(;
|
||||
int1 checka[)" << AS << "] = (rka < K)" << bca0 << R"(;
|
||||
int1 checkb[)" << BS << "] = (rkb < K)" << bcb0 << R"(;
|
||||
)" << a_ty_ << " a[" << AS << R"(] = checka ? *pa : 0;
|
||||
)" << b_ty_ << " b[" << BS << R"(] = checkb ? *pb : 0;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot()" << usea << "," << useb << R"(, C);
|
||||
int1 checka[)" << AS << R"(] = k > TK;
|
||||
@@ -287,18 +294,18 @@ else{
|
||||
}
|
||||
if(ty_ == WGRAD){
|
||||
os << R"(
|
||||
pb = pb + TK)" << ldb0 << R"(;
|
||||
rkb = rkb + TK;
|
||||
rbwhc = rkb / ABS;
|
||||
rbw = (rbwhc % AW)*stride_w;
|
||||
rbhc = rbwhc / AW;
|
||||
rbh = (rbhc % AH)*stride_h;
|
||||
rbwh = rkb / NB;
|
||||
rbb = rkb % NB;
|
||||
rbw = (rbwh % CW)*stride_w;
|
||||
rbh = (rbwh / CW)*stride_h;
|
||||
offkb = rbb + rbw*NB + rbh*NB*AW;
|
||||
interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));
|
||||
interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));
|
||||
interior = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
inc = interior ? shift : 0;
|
||||
shifted_pb = pb + inc;
|
||||
@checkb b = *shifted_pb;)";
|
||||
pb = pb_base + offkb[:, newaxis] + inc;
|
||||
@checkb b = *pb;)";
|
||||
}
|
||||
else{
|
||||
os << R"(
|
||||
@@ -311,20 +318,20 @@ else{
|
||||
int32 ryc[TN] = get_global_range[TN](1);)";
|
||||
if(ty_ == BPROP){
|
||||
os << R"(
|
||||
int32 rcwhc[TM] = rxc / ABS;
|
||||
int32 rcb[TM] = rxc % ABS;
|
||||
int32 rcw[TM] = (rcwhc % AW)*stride_w;
|
||||
int32 rchc[TM] = rcwhc / AW;
|
||||
int32 rch[TM] = (rchc % AH)*stride_h;
|
||||
rxc = rcb + rcw*ABS + rch*ABS*AW;
|
||||
int32 offc0[TM, TN] = rxc[:, newaxis];)";
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
int32 rcw[TM] = (rcwh % CW) * stride_w;
|
||||
int32 rch[TM] = (rcwh / CW) * stride_h;
|
||||
int32 offxc[TM] = rcb + rcw*NB + rch*NB*AW;
|
||||
)";
|
||||
}
|
||||
else{
|
||||
os << R"(
|
||||
int32 offc0[TM, TN] = rxc[:, newaxis];)";
|
||||
int32 offxc[TM] = rxc;
|
||||
)";
|
||||
}
|
||||
os << R"("
|
||||
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + offc0;
|
||||
fp32* pc[TM, TN] = c + ryc[newaxis, :]*ldc + offxc[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
|
||||
|
Reference in New Issue
Block a user