[dnn/shift]: added stride to shift

Philippe Tillet
2019-07-09 14:08:51 -07:00
parent cc41604784
commit 066ae338f1
4 changed files with 74 additions and 44 deletions

View File

@@ -49,35 +49,35 @@ def run_conv():
 def blocksparse_matmul_grad(op, dy):
   shift_h = op.get_attr('shift_h')
   shift_w = op.get_attr('shift_w')
+  stride_h = op.get_attr('stride_h')
+  stride_w = op.get_attr('stride_w')
   x = op.inputs[0]
   w = op.inputs[1]
-  dx = module.shift_conv_dx(dy, w, shift_h=shift_h, shift_w=shift_w)
-  dw = module.shift_conv_dw(dy, x, shift_h=shift_h, shift_w=shift_w)
+  dx = module.shift_conv_dx(dy, w, stride_h=stride_h, stride_w=stride_w, shift_h=shift_h, shift_w=shift_w)
+  dw = module.shift_conv_dw(dy, x, stride_h=stride_h, stride_w=stride_w, shift_h=shift_h, shift_w=shift_w)
   return (dx, dw)
 
 def run_shift():
-  B, C, H, W = 16, 1024, 8, 8
-  R, S, F = 3, 3, 1024
+  B, C, H, W = 16, 16, 4, 4
+  R, S, F = 3, 3, 4
+  stride_h, stride_w = 2, 2
   np.random.seed(2)
   a = tf.placeholder(tf.float32, shape=[C, H, W, B])
   b = tf.placeholder(tf.float32, shape=[C, F])
   hshift_h = np.random.randint(- (R//2), R//2 + 1, size=C, dtype=np.int32)
   hshift_w = np.random.randint(- (S//2), R//2 + 1, size=C, dtype=np.int32)
-  #hshift_h = np.ones(C, dtype=np.int32)
-  #hshift_w = np.ones(C, dtype=np.int32)
-  c = module.shift_conv(a, b, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
-  # Reference
+  c = module.shift_conv(a, b, stride_h=stride_h, stride_w=stride_w, shift_h=tf.make_tensor_proto(hshift_h), shift_w=tf.make_tensor_proto(hshift_w))
+  # feed values
   ha = np.random.rand(C, H, W, B)
   hb = np.random.rand(C, F)
-  #ha = np.ones((C, H, W, B), dtype=np.int32)
-  #hb = np.ones((C, F), dtype=np.int32)
   sess = tf.InteractiveSession()
-  #grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H, W, B),
-  #                                 extra_feed_dict = {a: ha, b: hb})
-  #dw_t, dw_n = grads[1]
-  #dx_t, dx_n = grads[0]
-  #print(np.max(np.abs(dw_t - dw_n)))
-  #print(np.max(np.abs(dx_t - dx_n)))
+  # test
+  grads = tf.test.compute_gradient([a, b], [(C, H, W, B), (C, F)], c, (F, H//stride_h, W//stride_w, B),
+                                   extra_feed_dict = {a: ha, b: hb})
+  dw_t, dw_n = grads[1]
+  dx_t, dx_n = grads[0]
+  print(np.max(np.abs(dw_t - dw_n)))
+  print(np.max(np.abs(dx_t - dx_n)))
   # Run
   sess.run(tf.global_variables_initializer())
   result = sess.run([c], feed_dict = {a: ha,
@@ -127,4 +127,4 @@ def run_batchnorm():
   print(np.max(np.abs(dg_t - dg_n)))
   print(np.max(np.abs(db_t - db_n)))
 
-run_batchnorm()
+run_shift()
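
For orientation, the operation this test now exercises can be sketched in plain NumPy: each input channel is read at its per-channel (shift_h, shift_w) offset, sampled on a stride_h x stride_w grid, and the gathered channel vector is multiplied by the C x F filter matrix. This is a minimal reference sketch, not the Triton kernel; the helper name shift_conv_ref and the zero-padding at image borders are assumptions (the kernel uses interior masks instead).

    import numpy as np

    def shift_conv_ref(a, b, shift_h, shift_w, stride_h, stride_w):
        # a: (C, H, W, B) input, b: (C, F) filter -- the layouts used above
        C, H, W, B = a.shape
        F = b.shape[1]
        P, Q = H // stride_h, W // stride_w          # strided output sizes
        c = np.zeros((F, P, Q, B), dtype=a.dtype)
        for p in range(P):
            for q in range(Q):
                x = np.zeros((C, B), dtype=a.dtype)  # one shifted pixel per channel
                for ch in range(C):
                    h = p * stride_h + shift_h[ch]
                    w = q * stride_w + shift_w[ch]
                    if 0 <= h < H and 0 <= w < W:    # zero outside the image
                        x[ch] = a[ch, h, w]
                c[:, p, q, :] = b.T @ x              # (F, C) @ (C, B)
        return c

Its output shape, (F, H//stride_h, W//stride_w, B), is the shape the updated gradient check passes to tf.test.compute_gradient.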

View File

@@ -34,6 +34,8 @@ public:
   explicit ShiftConvOp(OpKernelConstruction* context) : OpKernel(context) {
     context->GetAttr("shift_h", &h_shift_h_);
     context->GetAttr("shift_w", &h_shift_w_);
+    context->GetAttr("stride_h", &stride_h_);
+    context->GetAttr("stride_w", &stride_w_);
     R_ = 3;
     S_ = 3;
   }
@@ -52,12 +54,12 @@ public:
       int64_t Hb = tf_b.dim_size(1);
       int64_t Wb = tf_b.dim_size(2);
       int64_t Bb = tf_b.dim_size(3);
-      OP_REQUIRES(context, Ha == Hb, tensorflow::errors::InvalidArgument("operands must have the same image height"));
-      OP_REQUIRES(context, Wa == Wb, tensorflow::errors::InvalidArgument("operands must have the same image width"));
+      OP_REQUIRES(context, Ha*stride_h_ == Hb, tensorflow::errors::InvalidArgument("operands must have the same image height"));
+      OP_REQUIRES(context, Wa*stride_w_ == Wb, tensorflow::errors::InvalidArgument("operands must have the same image width"));
       OP_REQUIRES(context, Ba == Bb, tensorflow::errors::InvalidArgument("operands must have the same batch size"));
-      H = Ha;
-      W = Wa;
-      B = Ba;
+      H = Hb;
+      W = Wb;
+      B = Bb;
     }
     else {
       // shapes for a
@@ -65,6 +67,10 @@ public:
       H = tf_a.dim_size(1);
       W = tf_a.dim_size(2);
       B = tf_a.dim_size(3);
+      if(OP == triton::dnn::shift::BPROP){
+        H *= stride_h_;
+        W *= stride_w_;
+      }
       // shapes for b
       int64_t Cb = tf_b.dim_size(0);
       F = tf_b.dim_size(1);
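
The BPROP branch scales H and W back up because the dx op receives the strided gradient dy as operand a, while the shapes handed to triton::dnn::shift must describe the full-resolution activation. The operand shapes implied by the checks above, per op, can be summarized with a hypothetical helper (illustrative only, not part of the source):

    def shift_operand_shapes(B, C, F, H, W, stride_h, stride_w):
        # P, Q: strided output sizes; layouts follow the test script above
        P, Q = H // stride_h, W // stride_w
        return {
            'FPROP': dict(a=(C, H, W, B), b=(C, F),       c=(F, P, Q, B)),  # x, w  -> y
            'BPROP': dict(a=(F, P, Q, B), b=(C, F),       c=(C, H, W, B)),  # dy, w -> dx
            'WGRAD': dict(a=(F, P, Q, B), b=(C, H, W, B), c=(C, F)),        # dy, x -> dw
        }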
@@ -104,7 +110,9 @@ public:
     if(m_config.find(key) == m_config.end())
       shift = m_config.emplace(key, new triton::dnn::shift(
                                     B, C, D, H, W, T, R_, S_, F,
-                                    shift_h, shift_w, "fp32", "fp32", OP, has_bias))
+                                    stride_h_, stride_w_,
+                                    shift_h, shift_w,
+                                    "fp32", "fp32", OP, has_bias))
                        .first->second.get();
     else
       shift = m_config.at(key).get();
@@ -125,7 +133,7 @@ public:
     triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<float>().data(), false);
     // get JIT
     triton::jit* jit;
-    bool autotune = true;
+    bool autotune = false;
     if(m_jit.find(key) == m_jit.end()) {
       jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
       std::ostringstream oss;
@@ -171,6 +179,8 @@ public:
 private:
   Tensor h_shift_h_;
   Tensor h_shift_w_;
+  int stride_h_;
+  int stride_w_;
   int R_;
   int S_;
 };
@@ -181,6 +191,8 @@ REGISTER_OP("ShiftConv")
.Input("b: float32") .Input("b: float32")
.Attr("shift_h: tensor") .Attr("shift_h: tensor")
.Attr("shift_w: tensor") .Attr("shift_w: tensor")
.Attr("stride_h: int")
.Attr("stride_w: int")
.Output("c: float32"); .Output("c: float32");
REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>); REGISTER_KERNEL_BUILDER(Name("ShiftConvDx").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::BPROP>);
@@ -189,6 +201,8 @@ REGISTER_OP("ShiftConvDx")
.Input("b: float32") .Input("b: float32")
.Attr("shift_h: tensor") .Attr("shift_h: tensor")
.Attr("shift_w: tensor") .Attr("shift_w: tensor")
.Attr("stride_h: int")
.Attr("stride_w: int")
.Output("c: float32"); .Output("c: float32");
REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>); REGISTER_KERNEL_BUILDER(Name("ShiftConvDw").Device(DEVICE_GPU), ShiftConvOp<triton::dnn::shift::WGRAD>);
@@ -197,5 +211,7 @@ REGISTER_OP("ShiftConvDw")
.Input("b: float32") .Input("b: float32")
.Attr("shift_h: tensor") .Attr("shift_h: tensor")
.Attr("shift_w: tensor") .Attr("shift_w: tensor")
.Attr("stride_h: int")
.Attr("stride_w: int")
.Output("c: float32"); .Output("c: float32");

View File

@@ -52,6 +52,7 @@ public:
   shift(int B, int NC,
         int D, int H, int W,
         int T, int R, int S, int NF,
+        int stride_h, int stride_w,
         const std::vector<int32_t> &shift_h, const std::vector<int32_t> &shift_w,
         std::string a_ty = "fp32", std::string b_ty = "fp32",
         type ty = FPROP, bool bias = false);
@@ -133,6 +134,10 @@ private:
   std::vector<int32_t> shapes_a_;
   std::vector<int32_t> shapes_b_;
   std::vector<int32_t> shapes_c_;
+  // strides
+  int32_t stride_d_;
+  int32_t stride_h_;
+  int32_t stride_w_;
   // memory strides
   std::vector<int32_t> ld_a_;
   std::vector<int32_t> ld_b_;

View File

@@ -17,6 +17,7 @@ shift::shift(int B, int C,
              int D, int H, int W,
              int T, int R, int S,
              int F,
+             int stride_h, int stride_w,
              const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
              std::string a_ty, std::string b_ty,
              type ty, bool bias)
@@ -24,6 +25,7 @@ shift::shift(int B, int C,
     AD_(D), AH_(H), AW_(W),
     BD_(T), BH_(R), BW_(S),
     F_(F),
+    stride_d_(1), stride_h_(stride_h), stride_w_(stride_w),
     shift_h_(shift_h), shift_w_(shift_w),
     a_ty_(a_ty), b_ty_(b_ty),
     ty_(ty), bias_(bias) {
@@ -33,17 +35,21 @@ shift::shift(int B, int C,
   // transpose
   AT_ = false;
   BT_ = true;
+  // activation sizes
+  CD_ = AD_ / stride_d_;
+  CH_ = AH_ / stride_h_;
+  CW_ = AW_ / stride_w_;
   // equivalent matmul
-  M_ = B_*AH_*AW_;
+  M_ = B_*CH_*CW_;
   N_ = F_;
   K_ = C_;
   // shapes
   // input layout: C, H, W, B
   // filter layout: C, F
   // output layout: F, H, W, B
-  shapes_a_ = {C, H, W, B};
+  shapes_a_ = {C, AH_, AW_, B};
   shapes_b_ = {C, F};
-  shapes_c_ = {F, H, W, B};
+  shapes_c_ = {F, CH_, CW_, B};
   if(ty_ == WGRAD){
     shapes_b_.swap(shapes_c_);
     shapes_a_.swap(shapes_b_);
@@ -51,14 +57,14 @@ shift::shift(int B, int C,
     BT_ = false;
     M_ = F_;
     N_ = C_;
-    K_ = B_*AH_*AW_;
+    K_ = B_*CH_*CW_;
   }
   if(ty_ == BPROP){
     shapes_a_.swap(shapes_c_);
     AT_ = false;
     BT_ = false;
     K_ = F_;
-    M_ = B_*AH_*AW_;
+    M_ = B_*CH_*CW_;
     N_ = C_;
   }
   // memory strides
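
With the stride in place, the pixel dimension of the equivalent matmul is built from the strided activation sizes CH_ and CW_ rather than AH_ and AW_. The (M, N, K) choices set above can be summarized with an illustrative helper (stride_d_ is fixed to 1 in this commit):

    def shift_gemm_dims(B, C, F, H, W, stride_h, stride_w, ty):
        # (M, N, K) of the equivalent matmul, as set in the constructor above
        P, Q = H // stride_h, W // stride_w   # CH_, CW_
        if ty == 'FPROP':
            return B * P * Q, F, C            # pixels x filters, reduce over channels
        if ty == 'WGRAD':
            return F, C, B * P * Q            # reduce over strided output pixels
        if ty == 'BPROP':
            return B * P * Q, C, F            # pixels x channels, reduce over filters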
@@ -133,13 +139,15 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
   kernel->setArg(3, M_);
   kernel->setArg(4, N_);
   kernel->setArg(5, K_);
-  kernel->setArg(6, lda);
-  kernel->setArg(7, ldb);
-  kernel->setArg(8, B_);
-  kernel->setArg(9, AH_);
-  kernel->setArg(10, AW_);
-  kernel->setArg(11, BH_);
-  kernel->setArg(12, BW_);
+  kernel->setArg(6, stride_h_);
+  kernel->setArg(7, stride_w_);
+  kernel->setArg(8, lda);
+  kernel->setArg(9, ldb);
+  kernel->setArg(10, B_);
+  kernel->setArg(11, AH_);
+  kernel->setArg(12, AW_);
+  kernel->setArg(13, BH_);
+  kernel->setArg(14, BW_);
   std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
   if(ty_ == BPROP)
     ((driver::cu_buffer*)c)->set_zero(stream, M_*N_*4);
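
The two new kernel arguments take slots 6 and 7, pushing every later argument down by two. As a quick cross-check against the kernel signature in the next hunk (slots 0-2 are the a, b, c pointers; ABS, AR, AS receive B_, BH_, BW_):

    kernel_args = ['a', 'b', 'c', 'M', 'N', 'K',
                   'stride_h', 'stride_w', 'lda', 'ldb',
                   'ABS', 'AH', 'AW', 'AR', 'AS']
    assert kernel_args.index('stride_h') == 6
    assert kernel_args.index('AS') == 14   # last setArg slot above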
@@ -188,6 +196,7 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
                restrict read_only align(16) )" << b_ty_ << R"( *b,
                fp32 *c,
                int32 M, int32 N, int32 K,
+               int32 stride_h, int32 stride_w,
                int32 lda, int32 ldb,
                int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
   int32 rxa[TM] = get_global_range[TM](0);
@@ -200,9 +209,9 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
 if(ty_ == FPROP){
   os << R"(
   int32 rawhc[TM] = rxa / ABS;
-  int32 raw[TM] = rawhc % AW;
+  int32 raw[TM] = (rawhc % AW)*stride_w;
   int32 rahc[TM] = rawhc / AW;
-  int32 rah[TM] = rahc % AH;
+  int32 rah[TM] = (rahc % AH)*stride_h;
   __constant__ int32* pd[TK] = delta + rka;
   multiple_of(4) int32 d[TK] = *pd;
   int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
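
The FPROP row index rxa enumerates output pixels with batch fastest, then width, then height; the change multiplies the recovered w and h back by the stride so each load lands on the matching input coordinate. A NumPy rendering of that decomposition, using the test's sizes and taking the moduli as the strided output extents (an assumption about the intended semantics, written here as OH and OW):

    import numpy as np

    B, OH, OW = 16, 2, 2                 # output sizes for H = W = 4, stride = 2
    stride_h, stride_w = 2, 2
    rxa = np.arange(B * OH * OW)         # one entry per output pixel, as in M_
    rawhc = rxa // B                     # strip the batch dimension (ABS)
    raw = (rawhc % OW) * stride_w        # input column of each output pixel
    rahc = rawhc // OW
    rah = (rahc % OH) * stride_h         # input row of each output pixel
    assert set(zip(rah.tolist(), raw.tolist())) == {(0, 0), (0, 2), (2, 0), (2, 2)}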
@@ -227,9 +236,9 @@ if(ty_ == WGRAD){
 if(ty_ == WGRAD){
   os << R"(
   int32 rbwhc[TK] = rkb / ABS;
-  int32 rbw[TK] = rbwhc % AW;
+  int32 rbw[TK] = (rbwhc % AW)*stride_w;
   int32 rbhc[TK] = rbwhc / AW;
-  int32 rbh[TK] = rbhc % AH;
+  int32 rbh[TK] = (rbhc % AH)*stride_h;
   int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));
   int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));
   int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
@@ -266,9 +275,9 @@ if(ty_ == WGRAD){
   pb = pb + TK)" << ldb0 << R"(;
   rkb = rkb + TK;
   rbwhc = rkb / ABS;
-  rbw = rbwhc % AW;
+  rbw = (rbwhc % AW)*stride_w;
   rbhc = rbwhc / AW;
-  rbh = rbhc % AH;
+  rbh = (rbhc % AH)*stride_h;
   interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));
   interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));
   interior = interiorh[:, newaxis] && interiorw[:, newaxis];
@@ -292,9 +301,9 @@ else{
 if(ty_ == BPROP){
   os << R"(
   int32 rcwhc[TM] = rxc / ABS;
-  int32 rcw[TM] = rcwhc % AW;
+  int32 rcw[TM] = (rcwhc % AW)*stride_w;
   int32 rchc[TM] = rcwhc / AW;
-  int32 rch[TM] = rchc % AH;
+  int32 rch[TM] = (rchc % AH)*stride_h;
   int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));
   int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));
   int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];