[dnn/shift]: added stride to shift
This commit is contained in:
@@ -49,35 +49,35 @@ def run_conv():
|
||||
def blocksparse_matmul_grad(op, dy):
|
||||
shift_h = op.get_attr('shift_h')
|
||||
shift_w = op.get_attr('shift_w')
|
||||
stride_h = op.get_attr('stride_h')
|
||||
stride_w = op.get_attr('stride_w')
|
||||
x = op.inputs[0]
|
||||
w = op.inputs[1]
|
||||
dx = module.shift_conv_dx(dy, w, shift_h=shift_h, shift_w=shift_w)
|
||||
dw = module.shift_conv_dw(dy, x, shift_h=shift_h, shift_w=shift_w)
|
||||
dx = module.shift_conv_dx(dy, w, stride_h=stride_h, stride_w=stride_w, shift_h=shift_h, shift_w=shift_w)
|
||||
dw = module.shift_conv_dw(dy, x, stride_h=stride_h, stride_w=stride_w, shift_h=shift_h, shift_w=shift_w)
|
||||
return (dx, dw)
|
||||
|
||||
def run_shift():
|
||||
B, C, H, W = 16, 1024, 8, 8
|
||||
R, S, F = 3, 3, 1024
|
||||
B, C, H, W = 16, 16, 4, 4
|
||||
R, S, F = 3, 3, 4
|
||||
stride_h, stride_w = 2, 2
|
||||
np.random.seed(2)
|
||||
a = tf.placeholder(tf.float32, shape=[C, H, W, B])
|
||||
b = tf.placeholder(tf.float32, shape=[C, F])
|
||||
hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32)
|
||||
hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
|
||||
#hshift_h = np.ones(C, dtype=np.int32)
|
||||
#hshift_w = np.ones(C, dtype=np.int32)
|
||||
c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
|
||||
# Reference
|
||||
c = module.shift_conv(a, b, stride_h=stride_h, stride_w=stride_w, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
|
||||
# feed values
|
||||
ha = np.random.rand(C, H, W, B)
|
||||
hb = np.random.rand(C, F)
|
||||
#ha = np.ones((C, H, W, B), dtype=np.int32)
|
||||
#hb = np.ones((C, F), dtype=np.int32)
|
||||
sess = tf.InteractiveSession()
|
||||
#grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
|
||||
# extra_feed_dict = {a: ha, b: hb})
|
||||
#dw_t, dw_n = grads[1]
|
||||
#dx_t, dx_n = grads[0]
|
||||
#print(np.max(np.abs(dw_t - dw_n)))
|
||||
#print(np.max(np.abs(dx_t - dx_n)))
|
||||
# test
|
||||
grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H//stride_h, W//stride_w, B),
|
||||
extra_feed_dict = {a: ha, b: hb})
|
||||
dw_t, dw_n = grads[1]
|
||||
dx_t, dx_n = grads[0]
|
||||
print(np.max(np.abs(dw_t - dw_n)))
|
||||
print(np.max(np.abs(dx_t - dx_n)))
|
||||
# Run
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([c], feed_dict = {a: ha,
|
||||
@@ -127,4 +127,4 @@ def run_batchnorm():
|
||||
print(np.max(np.abs(dg_t - dg_n)))
|
||||
print(np.max(np.abs(db_t - db_n)))
|
||||
|
||||
run_batchnorm()
|
||||
run_shift()
|
||||
|
@@ -34,6 +34,8 @@ public:
|
||||
explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
context->GetAttr("shift_h", &h_shift_h_);
|
||||
context->GetAttr("shift_w", &h_shift_w_);
|
||||
context->GetAttr("stride_h", &stride_h_);
|
||||
context->GetAttr("stride_w", &stride_w_);
|
||||
R_ = 3;
|
||||
S_ = 3;
|
||||
}
|
||||
@@ -52,12 +54,12 @@ public:
|
||||
int64_t Hb = tf_b.dim_size(1);
|
||||
int64_t Wb = tf_b.dim_size(2);
|
||||
int64_t Bb = tf_b.dim_size(3);
|
||||
OP_REQUIRES(context, Ha == Hb, tensorflow::errors::InvalidArgument("operands must have the same image height"));
|
||||
OP_REQUIRES(context, Wa == Wb, tensorflow::errors::InvalidArgument("operands must have the same image width"));
|
||||
OP_REQUIRES(context, Ha*stride_h_ == Hb, tensorflow::errors::InvalidArgument("operands must have the same image height"));
|
||||
OP_REQUIRES(context, Wa*stride_w_ == Wb, tensorflow::errors::InvalidArgument("operands must have the same image width"));
|
||||
OP_REQUIRES(context, Ba == Bb, tensorflow::errors::InvalidArgument("operands must have the same batch size"));
|
||||
H = Ha;
|
||||
W = Wa;
|
||||
B = Ba;
|
||||
H = Hb;
|
||||
W = Wb;
|
||||
B = Bb;
|
||||
}
|
||||
else {
|
||||
// shapes for a
|
||||
@@ -65,6 +67,10 @@ public:
|
||||
H = tf_a.dim_size(1);
|
||||
W = tf_a.dim_size(2);
|
||||
B = tf_a.dim_size(3);
|
||||
if(OP == triton::dnn::shift::BPROP){
|
||||
H *= stride_h_;
|
||||
W *= stride_w_;
|
||||
}
|
||||
// shapes for b
|
||||
int64_t Cb = tf_b.dim_size(0);
|
||||
F = tf_b.dim_size(1);
|
||||
@@ -104,7 +110,9 @@ public:
|
||||
if(m_config.find(key) == m_config.end())
|
||||
shift = m_config.emplace(key, new triton::dnn::shift(
|
||||
B, C, D, H, W, T, R_, S_, F,
|
||||
shift_h, shift_w, "fp32", "fp32", OP, has_bias))
|
||||
stride_h_, stride_w_,
|
||||
shift_h, shift_w,
|
||||
"fp32", "fp32", OP, has_bias))
|
||||
.first->second.get();
|
||||
else
|
||||
shift = m_config.at(key).get();
|
||||
@@ -125,7 +133,7 @@ public:
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
|
||||
// get JIT
|
||||
triton::jit* jit;
|
||||
bool autotune = true;
|
||||
bool autotune = false;
|
||||
if(m_jit.find(key) == m_jit.end()) {
|
||||
jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
@@ -171,6 +179,8 @@ public:
|
||||
private:
|
||||
Tensor h_shift_h_;
|
||||
Tensor h_shift_w_;
|
||||
int stride_h_;
|
||||
int stride_w_;
|
||||
int R_;
|
||||
int S_;
|
||||
};
|
||||
@@ -181,6 +191,8 @@ REGISTER_OP("ShiftConv")
|
||||
.Input("b: float32")
|
||||
.Attr("shift_h: tensor")
|
||||
.Attr("shift_w: tensor")
|
||||
.Attr("stride_h: int")
|
||||
.Attr("stride_w: int")
|
||||
.Output("c: float32");
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
|
||||
@@ -189,6 +201,8 @@ REGISTER_OP("ShiftConvDx")
|
||||
.Input("b: float32")
|
||||
.Attr("shift_h: tensor")
|
||||
.Attr("shift_w: tensor")
|
||||
.Attr("stride_h: int")
|
||||
.Attr("stride_w: int")
|
||||
.Output("c: float32");
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
|
||||
@@ -197,5 +211,7 @@ REGISTER_OP("ShiftConvDw")
|
||||
.Input("b: float32")
|
||||
.Attr("shift_h: tensor")
|
||||
.Attr("shift_w: tensor")
|
||||
.Attr("stride_h: int")
|
||||
.Attr("stride_w: int")
|
||||
.Output("c: float32");
|
||||
|
||||
|
@@ -52,6 +52,7 @@ public:
|
||||
shift(int B, int NC,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S, int NF,
|
||||
int stride_h, int stride_w,
|
||||
const std::vector<int32_t> &shift_h, const std::vector<int32_t> &shift_w,
|
||||
std::string a_ty = "fp32", std::string b_ty = "fp32",
|
||||
type ty = FPROP, bool bias = false);
|
||||
@@ -133,6 +134,10 @@ private:
|
||||
std::vector<int32_t> shapes_a_;
|
||||
std::vector<int32_t> shapes_b_;
|
||||
std::vector<int32_t> shapes_c_;
|
||||
// strides
|
||||
int32_t stride_d_;
|
||||
int32_t stride_h_;
|
||||
int32_t stride_w_;
|
||||
// memory strides
|
||||
std::vector<int32_t> ld_a_;
|
||||
std::vector<int32_t> ld_b_;
|
||||
|
@@ -17,6 +17,7 @@ shift::shift(int B, int C,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S,
|
||||
int F,
|
||||
int stride_h, int stride_w,
|
||||
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
|
||||
std::string a_ty, std::string b_ty,
|
||||
type ty, bool bias)
|
||||
@@ -24,6 +25,7 @@ shift::shift(int B, int C,
|
||||
AD_(D), AH_(H), AW_(W),
|
||||
BD_(T), BH_(R), BW_(S),
|
||||
F_(F),
|
||||
stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
|
||||
shift_h_(shift_h), shift_w_(shift_w),
|
||||
a_ty_(a_ty), b_ty_(b_ty),
|
||||
ty_(ty), bias_(bias) {
|
||||
@@ -33,17 +35,21 @@ shift::shift(int B, int C,
|
||||
// transpose
|
||||
AT_ = false;
|
||||
BT_ = true;
|
||||
// activation sizes
|
||||
CD_ = AD_ / stride_d_;
|
||||
CH_ = AH_ / stride_h_;
|
||||
CW_ = AW_ / stride_w_;
|
||||
// equivalent matmul
|
||||
M_ = B_*AH_*AW_;
|
||||
M_ = B_*CH_*CW_;
|
||||
N_ = F_;
|
||||
K_ = C_;
|
||||
// shapes
|
||||
// input layout: C, H, W, B
|
||||
// filter layout: C, F
|
||||
// output layout: F, H, W, B
|
||||
shapes_a_ = {C, H, W, B};
|
||||
shapes_a_ = {C, AH_, AW_, B};
|
||||
shapes_b_ = {C, F};
|
||||
shapes_c_ = {F, H, W, B};
|
||||
shapes_c_ = {F, CH_, CW_, B};
|
||||
if(ty_ == WGRAD){
|
||||
shapes_b_.swap(shapes_c_);
|
||||
shapes_a_.swap(shapes_b_);
|
||||
@@ -51,14 +57,14 @@ shift::shift(int B, int C,
|
||||
BT_ = false;
|
||||
M_ = F_;
|
||||
N_ = C_;
|
||||
K_ = B_*AH_*AW_;
|
||||
K_ = B_*CH_*CW_;
|
||||
}
|
||||
if(ty_ == BPROP){
|
||||
shapes_a_.swap(shapes_c_);
|
||||
AT_ = false;
|
||||
BT_ = false;
|
||||
K_ = F_;
|
||||
M_ = B_*AH_*AW_;
|
||||
M_ = B_*CH_*CW_;
|
||||
N_ = C_;
|
||||
}
|
||||
// memory strides
|
||||
@@ -133,13 +139,15 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, lda);
|
||||
kernel->setArg(7, ldb);
|
||||
kernel->setArg(8, B_);
|
||||
kernel->setArg(9, AH_);
|
||||
kernel->setArg(10, AW_);
|
||||
kernel->setArg(11, BH_);
|
||||
kernel->setArg(12, BW_);
|
||||
kernel->setArg(6, stride_h_);
|
||||
kernel->setArg(7, stride_w_);
|
||||
kernel->setArg(8, lda);
|
||||
kernel->setArg(9, ldb);
|
||||
kernel->setArg(10, B_);
|
||||
kernel->setArg(11, AH_);
|
||||
kernel->setArg(12, AW_);
|
||||
kernel->setArg(13, BH_);
|
||||
kernel->setArg(14, BW_);
|
||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||
if(ty_ == BPROP)
|
||||
((driver::cu_buffer*)c)->set_zero(stream, M_*N_*4);
|
||||
@@ -188,6 +196,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
restrict read_only align(16) )" << b_ty_ << R"( *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 stride_h, int32 stride_w,
|
||||
int32 lda, int32 ldb,
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
@@ -200,9 +209,9 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
|
||||
if(ty_ == FPROP){
|
||||
os << R"(
|
||||
int32 rawhc[TM] = rxa / ABS;
|
||||
int32 raw[TM] = rawhc % AW;
|
||||
int32 raw[TM] = (rawhc % AW)*stride_w;
|
||||
int32 rahc[TM] = rawhc / AW;
|
||||
int32 rah[TM] = rahc % AH;
|
||||
int32 rah[TM] = (rahc % AH)*stride_h;
|
||||
__constant__ int32* pd[TK] = delta + rka;
|
||||
multiple_of(4) int32 d[TK] = *pd;
|
||||
int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
|
||||
@@ -227,9 +236,9 @@ if(ty_ == WGRAD){
|
||||
if(ty_ == WGRAD){
|
||||
os << R"(
|
||||
int32 rbwhc[TK] = rkb / ABS;
|
||||
int32 rbw[TK] = rbwhc % AW;
|
||||
int32 rbw[TK] = (rbwhc % AW)*stride_w;
|
||||
int32 rbhc[TK] = rbwhc / AW;
|
||||
int32 rbh[TK] = rbhc % AH;
|
||||
int32 rbh[TK] = (rbhc % AH)*stride_h;
|
||||
int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));
|
||||
int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
|
||||
int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
@@ -266,9 +275,9 @@ if(ty_ == WGRAD){
|
||||
pb = pb + TK)" << ldb0 << R"(;
|
||||
rkb = rkb + TK;
|
||||
rbwhc = rkb / ABS;
|
||||
rbw = rbwhc % AW;
|
||||
rbw = (rbwhc % AW)*stride_w;
|
||||
rbhc = rbwhc / AW;
|
||||
rbh = rbhc % AH;
|
||||
rbh = (rbhc % AH)*stride_h;
|
||||
interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));
|
||||
interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));
|
||||
interior = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
@@ -292,9 +301,9 @@ else{
|
||||
if(ty_ == BPROP){
|
||||
os << R"(
|
||||
int32 rcwhc[TM] = rxc / ABS;
|
||||
int32 rcw[TM] = rcwhc % AW;
|
||||
int32 rcw[TM] = (rcwhc % AW)*stride_w;
|
||||
int32 rchc[TM] = rcwhc / AW;
|
||||
int32 rch[TM] = rchc % AH;
|
||||
int32 rch[TM] = (rchc % AH)*stride_h;
|
||||
int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
|
||||
int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
|
||||
int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
|
Reference in New Issue
Block a user