Backend: Fixed alpha, beta in GEMM.
This commit is contained in:
@@ -43,7 +43,7 @@ public:
|
|||||||
ISAAC_INSTANTIATE(float)
|
ISAAC_INSTANTIATE(float)
|
||||||
ISAAC_INSTANTIATE(double)
|
ISAAC_INSTANTIATE(double)
|
||||||
#undef ISAAC_INSTANTIATE
|
#undef ISAAC_INSTANTIATE
|
||||||
value_scalar(values_holder value, numeric_type dtype);
|
value_scalar(values_holder values, numeric_type dtype);
|
||||||
explicit value_scalar(scalar const &);
|
explicit value_scalar(scalar const &);
|
||||||
explicit value_scalar(array_expression const &);
|
explicit value_scalar(array_expression const &);
|
||||||
explicit value_scalar(numeric_type dtype);
|
explicit value_scalar(numeric_type dtype);
|
||||||
|
@@ -171,7 +171,7 @@ array& array::operator=(controller<TYPE> const & c)
|
|||||||
template<class DT>
|
template<class DT>
|
||||||
array & array::operator=(std::vector<DT> const & rhs)
|
array & array::operator=(std::vector<DT> const & rhs)
|
||||||
{
|
{
|
||||||
assert(nshape()==1);
|
assert(nshape()<=1);
|
||||||
isaac::copy(rhs, *this);
|
isaac::copy(rhs, *this);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@@ -246,20 +246,20 @@ array_expression array::T() const
|
|||||||
//---------------------------------------
|
//---------------------------------------
|
||||||
scalar array::operator [](int_t idx)
|
scalar array::operator [](int_t idx)
|
||||||
{
|
{
|
||||||
assert(nshape()==1);
|
assert(nshape()<=1);
|
||||||
return scalar(dtype_, data_, idx);
|
return scalar(dtype_, data_, idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
const scalar array::operator [](int_t idx) const
|
const scalar array::operator [](int_t idx) const
|
||||||
{
|
{
|
||||||
assert(nshape()==1);
|
assert(nshape()<=1);
|
||||||
return scalar(dtype_, data_, idx);
|
return scalar(dtype_, data_, idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
array array::operator[](slice const & e1)
|
array array::operator[](slice const & e1)
|
||||||
{
|
{
|
||||||
assert(nshape()==1);
|
assert(nshape()<=1);
|
||||||
return array(*this, e1);
|
return array(*this, e1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -717,7 +717,10 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
|
|||||||
|
|
||||||
|
|
||||||
value_scalar beta(0, dtype);
|
value_scalar beta(0, dtype);
|
||||||
|
if(args.beta) beta = value_scalar(args.beta->vscalar, dtype);
|
||||||
|
|
||||||
value_scalar alpha(1, dtype);
|
value_scalar alpha(1, dtype);
|
||||||
|
if(args.alpha) alpha = value_scalar(args.alpha->vscalar, dtype);
|
||||||
|
|
||||||
|
|
||||||
execution_options_type const & options = ctr.execution_options();
|
execution_options_type const & options = ctr.execution_options();
|
||||||
|
@@ -48,7 +48,7 @@ INSTANTIATE(double)
|
|||||||
value_scalar::value_scalar(numeric_type dtype) : dtype_(dtype) {}
|
value_scalar::value_scalar(numeric_type dtype) : dtype_(dtype) {}
|
||||||
value_scalar::value_scalar(scalar const & s) : dtype_(s.dtype()) { s.inject(values_); }
|
value_scalar::value_scalar(scalar const & s) : dtype_(s.dtype()) { s.inject(values_); }
|
||||||
value_scalar::value_scalar(array_expression const &expr) : dtype_(expr.dtype()) { scalar(expr).inject(values_); }
|
value_scalar::value_scalar(array_expression const &expr) : dtype_(expr.dtype()) { scalar(expr).inject(values_); }
|
||||||
|
value_scalar::value_scalar(values_holder values, numeric_type dtype) : values_(values), dtype_(dtype) {}
|
||||||
values_holder value_scalar::values() const
|
values_holder value_scalar::values() const
|
||||||
{ return values_; }
|
{ return values_; }
|
||||||
|
|
||||||
|
@@ -119,7 +119,6 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
|
|||||||
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP, "C++, FULL");
|
test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP, "C++, FULL");
|
||||||
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP, "C++, SUB");
|
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP, "C++, SUB");
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main()
|
int main()
|
||||||
|
Reference in New Issue
Block a user