[dnn/shift] bugfix in wgrad

This commit is contained in:
Philippe Tillet
2019-07-06 11:27:49 -07:00
parent 3e49dbe6ab
commit b0cf3143c5
3 changed files with 5 additions and 3 deletions

View File

@@ -68,6 +68,8 @@ public:
// shapes for b
int64_t Cb = tf_b.dim_size(0);
F = tf_b.dim_size(1);
if(OP == triton::dnn::shift::BPROP)
std::swap(Cb, F);
// checks
OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
C = Ca;