[dnn/shift] bugfix in wgrad
This commit is contained in:
@@ -68,6 +68,8 @@ public:
|
||||
// shapes for b
|
||||
int64_t Cb = tf_b.dim_size(0);
|
||||
F = tf_b.dim_size(1);
|
||||
if(OP == triton::dnn::shift::BPROP)
|
||||
std::swap(Cb, F);
|
||||
// checks
|
||||
OP_REQUIRES(context, Ca == Cb, tensorflow::errors::InvalidArgument("operands must have the same number of channels"));
|
||||
C = Ca;
|
||||
|
Reference in New Issue
Block a user