Backend: Fixed alpha, beta in GEMM.

This commit is contained in:
Philippe Tillet
2015-06-29 21:52:50 -07:00
parent cf2dba43ef
commit 9d0d50ba05
5 changed files with 16 additions and 14 deletions

View File

@@ -43,7 +43,7 @@ public:
ISAAC_INSTANTIATE(float) ISAAC_INSTANTIATE(float)
ISAAC_INSTANTIATE(double) ISAAC_INSTANTIATE(double)
#undef ISAAC_INSTANTIATE #undef ISAAC_INSTANTIATE
value_scalar(values_holder value, numeric_type dtype); value_scalar(values_holder values, numeric_type dtype);
explicit value_scalar(scalar const &); explicit value_scalar(scalar const &);
explicit value_scalar(array_expression const &); explicit value_scalar(array_expression const &);
explicit value_scalar(numeric_type dtype); explicit value_scalar(numeric_type dtype);

View File

@@ -171,7 +171,7 @@ array& array::operator=(controller<TYPE> const & c)
template<class DT> template<class DT>
array & array::operator=(std::vector<DT> const & rhs) array & array::operator=(std::vector<DT> const & rhs)
{ {
assert(nshape()==1); assert(nshape()<=1);
isaac::copy(rhs, *this); isaac::copy(rhs, *this);
return *this; return *this;
} }
@@ -246,20 +246,20 @@ array_expression array::T() const
//--------------------------------------- //---------------------------------------
scalar array::operator [](int_t idx) scalar array::operator [](int_t idx)
{ {
assert(nshape()==1); assert(nshape()<=1);
return scalar(dtype_, data_, idx); return scalar(dtype_, data_, idx);
} }
const scalar array::operator [](int_t idx) const const scalar array::operator [](int_t idx) const
{ {
assert(nshape()==1); assert(nshape()<=1);
return scalar(dtype_, data_, idx); return scalar(dtype_, data_, idx);
} }
array array::operator[](slice const & e1) array array::operator[](slice const & e1)
{ {
assert(nshape()==1); assert(nshape()<=1);
return array(*this, e1); return array(*this, e1);
} }

View File

@@ -717,7 +717,10 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
value_scalar beta(0, dtype); value_scalar beta(0, dtype);
if(args.beta) beta = value_scalar(args.beta->vscalar, dtype);
value_scalar alpha(1, dtype); value_scalar alpha(1, dtype);
if(args.alpha) alpha = value_scalar(args.alpha->vscalar, dtype);
execution_options_type const & options = ctr.execution_options(); execution_options_type const & options = ctr.execution_options();

View File

@@ -48,7 +48,7 @@ INSTANTIATE(double)
value_scalar::value_scalar(numeric_type dtype) : dtype_(dtype) {} value_scalar::value_scalar(numeric_type dtype) : dtype_(dtype) {}
value_scalar::value_scalar(scalar const & s) : dtype_(s.dtype()) { s.inject(values_); } value_scalar::value_scalar(scalar const & s) : dtype_(s.dtype()) { s.inject(values_); }
value_scalar::value_scalar(array_expression const &expr) : dtype_(expr.dtype()) { scalar(expr).inject(values_); } value_scalar::value_scalar(array_expression const &expr) : dtype_(expr.dtype()) { scalar(expr).inject(values_); }
value_scalar::value_scalar(values_holder values, numeric_type dtype) : values_(values), dtype_(dtype) {}
values_holder value_scalar::values() const values_holder value_scalar::values() const
{ return values_; } { return values_; }

View File

@@ -119,7 +119,6 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP, "C++, FULL"); test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP, "C++, FULL");
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP, "C++, SUB"); test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP, "C++, SUB");
} }
} }
int main() int main()