Backend: Fixed alpha, beta in GEMM.

This commit is contained in:
Philippe Tillet
2015-06-29 21:52:50 -07:00
parent cf2dba43ef
commit 9d0d50ba05
5 changed files with 16 additions and 14 deletions

View File

@@ -43,7 +43,7 @@ public:
ISAAC_INSTANTIATE(float)
ISAAC_INSTANTIATE(double)
#undef ISAAC_INSTANTIATE
value_scalar(values_holder value, numeric_type dtype);
value_scalar(values_holder values, numeric_type dtype);
explicit value_scalar(scalar const &);
explicit value_scalar(array_expression const &);
explicit value_scalar(numeric_type dtype);

View File

@@ -171,7 +171,7 @@ array& array::operator=(controller<TYPE> const & c)
template<class DT>
array & array::operator=(std::vector<DT> const & rhs)
{
assert(nshape()==1);
assert(nshape()<=1);
isaac::copy(rhs, *this);
return *this;
}
@@ -246,20 +246,20 @@ array_expression array::T() const
//---------------------------------------
scalar array::operator [](int_t idx)
{
assert(nshape()==1);
assert(nshape()<=1);
return scalar(dtype_, data_, idx);
}
const scalar array::operator [](int_t idx) const
{
assert(nshape()==1);
assert(nshape()<=1);
return scalar(dtype_, data_, idx);
}
array array::operator[](slice const & e1)
{
assert(nshape()==1);
assert(nshape()<=1);
return array(*this, e1);
}

View File

@@ -717,7 +717,10 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
value_scalar beta(0, dtype);
if(args.beta) beta = value_scalar(args.beta->vscalar, dtype);
value_scalar alpha(1, dtype);
if(args.alpha) alpha = value_scalar(args.alpha->vscalar, dtype);
execution_options_type const & options = ctr.execution_options();

View File

@@ -48,7 +48,7 @@ INSTANTIATE(double)
value_scalar::value_scalar(numeric_type dtype) : dtype_(dtype) {}
value_scalar::value_scalar(scalar const & s) : dtype_(s.dtype()) { s.inject(values_); }
value_scalar::value_scalar(array_expression const &expr) : dtype_(expr.dtype()) { scalar(expr).inject(values_); }
value_scalar::value_scalar(values_holder values, numeric_type dtype) : values_(values), dtype_(dtype) {}
values_holder value_scalar::values() const
{ return values_; }

View File

@@ -119,7 +119,6 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP, "C++, FULL");
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP, "C++, SUB");
}
}
int main()