Backend: Fixed alpha, beta in GEMM.
This commit is contained in:
@@ -43,7 +43,7 @@ public:
|
||||
ISAAC_INSTANTIATE(float)
|
||||
ISAAC_INSTANTIATE(double)
|
||||
#undef ISAAC_INSTANTIATE
|
||||
value_scalar(values_holder value, numeric_type dtype);
|
||||
value_scalar(values_holder values, numeric_type dtype);
|
||||
explicit value_scalar(scalar const &);
|
||||
explicit value_scalar(array_expression const &);
|
||||
explicit value_scalar(numeric_type dtype);
|
||||
|
@@ -171,7 +171,7 @@ array& array::operator=(controller<TYPE> const & c)
|
||||
template<class DT>
|
||||
array & array::operator=(std::vector<DT> const & rhs)
|
||||
{
|
||||
assert(nshape()==1);
|
||||
assert(nshape()<=1);
|
||||
isaac::copy(rhs, *this);
|
||||
return *this;
|
||||
}
|
||||
@@ -246,20 +246,20 @@ array_expression array::T() const
|
||||
//---------------------------------------
|
||||
scalar array::operator [](int_t idx)
|
||||
{
|
||||
assert(nshape()==1);
|
||||
assert(nshape()<=1);
|
||||
return scalar(dtype_, data_, idx);
|
||||
}
|
||||
|
||||
const scalar array::operator [](int_t idx) const
|
||||
{
|
||||
assert(nshape()==1);
|
||||
assert(nshape()<=1);
|
||||
return scalar(dtype_, data_, idx);
|
||||
}
|
||||
|
||||
|
||||
array array::operator[](slice const & e1)
|
||||
{
|
||||
assert(nshape()==1);
|
||||
assert(nshape()<=1);
|
||||
return array(*this, e1);
|
||||
}
|
||||
|
||||
|
@@ -717,7 +717,10 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
||||
|
||||
|
||||
value_scalar beta(0, dtype);
|
||||
if(args.beta) beta = value_scalar(args.beta->vscalar, dtype);
|
||||
|
||||
value_scalar alpha(1, dtype);
|
||||
if(args.alpha) alpha = value_scalar(args.alpha->vscalar, dtype);
|
||||
|
||||
|
||||
execution_options_type const & options = ctr.execution_options();
|
||||
|
@@ -48,7 +48,7 @@ INSTANTIATE(double)
|
||||
value_scalar::value_scalar(numeric_type dtype) : dtype_(dtype) {}
|
||||
value_scalar::value_scalar(scalar const & s) : dtype_(s.dtype()) { s.inject(values_); }
|
||||
value_scalar::value_scalar(array_expression const &expr) : dtype_(expr.dtype()) { scalar(expr).inject(values_); }
|
||||
|
||||
value_scalar::value_scalar(values_holder values, numeric_type dtype) : values_(values), dtype_(dtype) {}
|
||||
values_holder value_scalar::values() const
|
||||
{ return values_; }
|
||||
|
||||
|
@@ -119,7 +119,6 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
|
||||
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP, "C++, FULL");
|
||||
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP, "C++, SUB");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
int main()
|
||||
|
Reference in New Issue
Block a user