More flexibility in scalars

This commit is contained in:
Philippe Tillet
2015-01-19 21:29:47 -05:00
parent 8694bacaab
commit 4f73fb384f
18 changed files with 127 additions and 113 deletions

View File

@@ -32,8 +32,8 @@ public:
//General constructor
array(numeric_type dtype, cl::Buffer data, slice const & s1, slice const & s2, int_t ld, cl::Context context = cl::default_context());
array(array_expression const & proxy);
array(array const &);
//Getters
numeric_type dtype() const;
size4 shape() const;
@@ -53,7 +53,7 @@ public:
array& operator=(array_expression const &);
template<class T> array & operator=(std::vector<T> const & rhs);
array& operator-();
array_expression operator-();
array& operator+=(value_scalar const &);
array& operator+=(array const &);
array& operator+=(array_expression const &);
@@ -93,10 +93,9 @@ public:
explicit scalar(value_scalar value, cl::Context context = cl::default_context());
explicit scalar(numeric_type dtype, cl::Context context = cl::default_context());
scalar(array_expression const & proxy);
scalar& operator=(value_scalar const &);
scalar& operator=(scalar const &);
using array::operator=;
// scalar& operator=(scalar const & s);
using array::operator =;
#define INSTANTIATE(type) operator type() const;
INSTANTIATE(bool)
@@ -212,7 +211,7 @@ ATIDLAS_DECLARE_REDUCTION(argmin)
//
std::ostream& operator<<(std::ostream &, array const &);
std::ostream& operator<<(std::ostream &, array_expression const &);
std::ostream& operator<<(std::ostream & os, scalar const & s);
}
#endif

View File

@@ -162,17 +162,6 @@ private:
std::string pointer_;
};
/** @brief Scalar
*
* Maps a scalar passed by pointer
*/
class mapped_scalar : public mapped_handle
{
public:
mapped_scalar(std::string const & scalartype, unsigned int id);
};
/** @brief Buffered
*
* Maps a buffered object (vector, matrix)

View File

@@ -218,19 +218,7 @@ struct array_expression: public symbolic_expression
array_expression& reshape(int_t size1, int_t size2=1);
int_t nshape() const;
array_expression& operator-();
array_expression& operator+=(value_scalar const &);
array_expression& operator+=(array const &);
array_expression& operator+=(array_expression const &);
array_expression& operator-=(value_scalar const &);
array_expression& operator-=(array const &);
array_expression& operator-=(array_expression const &);
array_expression& operator*=(value_scalar const &);
array_expression& operator*=(array const &);
array_expression& operator*=(array_expression const &);
array_expression& operator/=(value_scalar const &);
array_expression& operator/=(array const &);
array_expression& operator/=(array_expression const &);
array_expression operator-();
private:
size4 shape_;
};

View File

@@ -2,6 +2,7 @@
#define ATIDLAS_TOOLS_TO_STRING_HPP
#include <string>
#include <sstream>
namespace atidlas
{

View File

@@ -2,6 +2,7 @@
#define ATIDLAS_TYPES_H
#include "atidlas/cl/cl.hpp"
#include "atidlas/exception/unknown_datatype.h"
namespace atidlas
{
@@ -59,7 +60,7 @@ inline std::string numeric_type_to_string(numeric_type const & type)
case ULONG_TYPE: return "ulong";
case FLOAT_TYPE : return "float";
case DOUBLE_TYPE : return "double";
default : throw "Unsupported Scalartype";
default : throw unknown_datatype(type);
}
}
@@ -94,7 +95,7 @@ inline unsigned int size_of(numeric_type type)
case LONG_TYPE:
case DOUBLE_TYPE: return 8;
default: throw "Unsupported numeric type";
default: throw unknown_datatype(type);
}
}

View File

@@ -2,6 +2,7 @@
#include "atidlas/array.h"
#include "atidlas/cl/cl.hpp"
#include "atidlas/exception/unknown_datatype.h"
#include "atidlas/model/model.h"
#include "atidlas/symbolic/execute.h"
@@ -166,7 +167,29 @@ INSTANTIATE(cl_ulong);
INSTANTIATE(cl_float);
INSTANTIATE(cl_double);
#undef INSTANTIATE
array_expression array::operator-()
{ return array_expression(*this, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
//
array & array::operator+=(value_scalar const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
array & array::operator+=(array const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), context_, dtype_, shape_); }
array & array::operator+=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_ADD_TYPE), shape_); }
//----
array & array::operator-=(value_scalar const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
array & array::operator-=(array const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), context_, dtype_, shape_); }
array & array::operator-=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), shape_); }
//----
array & array::operator*=(value_scalar const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), context_, dtype_, shape_); }
@@ -175,7 +198,7 @@ array & array::operator*=(array const & rhs)
array & array::operator*=(array_expression const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_MULT_TYPE), shape_); }
//----
array & array::operator/=(value_scalar const & rhs)
{ return *this = array_expression(*this, rhs, op_element(OPERATOR_BINARY_TYPE_FAMILY, OPERATOR_DIV_TYPE), context_, dtype_, shape_); }
@@ -193,6 +216,13 @@ scalar array::operator [](int_t idx)
return scalar(dtype_, data_, idx, context_);
}
const scalar array::operator [](int_t idx) const
{
assert(nshape()==1);
return scalar(dtype_, data_, idx, context_);
}
array array::operator[](slice const & e1)
{
assert(nshape()==1);
@@ -215,7 +245,7 @@ void copy(cl::Context & ctx, cl::Buffer const & data, T value)
}
scalar::scalar(numeric_type dtype, const cl::Buffer &data, int_t offset, cl::Context context): array(dtype, data, _(offset, offset+1), _(1,1), 1, context)
scalar::scalar(numeric_type dtype, const cl::Buffer &data, int_t offset, cl::Context context): array(dtype, data, _(offset, offset+1), _(1,2), 1, context)
{ }
scalar::scalar(value_scalar value, cl::Context context) : array(1, value.dtype(), context)
@@ -232,12 +262,13 @@ scalar::scalar(value_scalar value, cl::Context context) : array(1, value.dtype()
case ULONG_TYPE: detail::copy(context_, data_, (cl_ulong)value); break;
case FLOAT_TYPE: detail::copy(context_, data_, (cl_float)value); break;
case DOUBLE_TYPE: detail::copy(context_, data_, (cl_double)value); break;
default: throw "unrecognized datatype";
default: throw unknown_datatype(dtype_);
}
}
scalar::scalar(numeric_type dtype, cl::Context context) : array(1, dtype, context){ }
scalar::scalar(numeric_type dtype, cl::Context context) : array(1, dtype, context)
{ }
scalar::scalar(array_expression const & proxy) : array(proxy){ }
@@ -263,7 +294,7 @@ case DTYPE:\
HANDLE_CASE(ULONG_TYPE, uint64);
HANDLE_CASE(FLOAT_TYPE, float32);
HANDLE_CASE(DOUBLE_TYPE, float64);
default: throw "Datatype not recognized";
default: throw unknown_datatype(dtype_);
}
#undef HANDLE_CASE
@@ -292,12 +323,14 @@ scalar& scalar::operator=(value_scalar const & s)
HANDLE_CASE(ULONG_TYPE, cl_ulong)
HANDLE_CASE(FLOAT_TYPE, cl_float)
HANDLE_CASE(DOUBLE_TYPE, cl_double)
default: throw "Datatype not recognized";
default: throw unknown_datatype(dtype_);
}
}
scalar& scalar::operator=(scalar const & s)
{ return (scalar&)array::operator =(s); }
//scalar& scalar::operator=(scalar const & s)
//{
// return scalar::operator =(value_scalar(s));
//}
#define INSTANTIATE(type) scalar::operator type() const { return cast<type>(); }
INSTANTIATE(cl_char)
@@ -327,7 +360,7 @@ std::ostream & operator<<(std::ostream & os, scalar const & s)
case HALF_TYPE: return os << static_cast<cl_half>(s);
case FLOAT_TYPE: return os << static_cast<cl_float>(s);
case DOUBLE_TYPE: return os << static_cast<cl_double>(s);
default: throw "";
default: throw unknown_datatype(s.dtype());
}
}
@@ -489,7 +522,7 @@ array_expression OPNAME(array const & x, int_t axis)\
else if(axis==1)\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), x.context(), x.dtype(), size4(x.shape()._2));\
else\
throw "invalid shape";\
throw ;\
}\
\
array_expression OPNAME(array_expression const & x, int_t axis)\
@@ -501,7 +534,7 @@ array_expression OPNAME(array_expression const & x, int_t axis)\
else if(axis==1)\
return array_expression(x, lhs_rhs_element(), op_element(OPERATOR_COLUMNS_REDUCTION_TYPE_FAMILY, OP), size4(x.shape()._2));\
else\
throw "invalid shape";\
throw ;\
}
DEFINE_REDUCTION(OPERATOR_ADD_TYPE, sum)

View File

@@ -167,9 +167,6 @@ mapped_tuple::mapped_tuple(std::string const & scalartype, unsigned int id, size
mapped_handle::mapped_handle(std::string const & scalartype, unsigned int id, std::string const & type_key) : mapped_object(scalartype, id, type_key)
{ register_attribute(pointer_, "#pointer", name_ + "_pointer"); }
//
mapped_scalar::mapped_scalar(std::string const & scalartype, unsigned int id) : mapped_handle(scalartype, id, "scalar") { }
//
mapped_buffer::mapped_buffer(std::string const & scalartype, unsigned int id, std::string const & type_key) : mapped_handle(scalartype, id, type_key){ }
@@ -223,9 +220,13 @@ void mapped_array::preprocess(std::string & str) const
replace_macro(str, "$OFFSET", MorphOffset(ld_, type_));
}
mapped_array::mapped_array(std::string const & scalartype, unsigned int id, char type) : mapped_buffer(scalartype, id, type=='m'?"array2":"array1"), type_(type)
mapped_array::mapped_array(std::string const & scalartype, unsigned int id, char type) : mapped_buffer(scalartype, id, type=='s'?"array0":(type=='m'?"array2":"array1")), type_(type)
{
if(type_=='m')
if(type_ == 's')
{
register_attribute(start1_, "#start", name_ + "_start");
}
else if(type_=='m')
{
register_attribute(start1_, "#start1", name_ + "_start1");
register_attribute(start2_, "#start2", name_ + "_start2");

View File

@@ -277,13 +277,13 @@ void evaluate_expression_traversal::operator()(atidlas::symbolic_expression cons
{
symbolic_expression_node const & root_node = symbolic_expression.tree()[root_idx];
mapping_type::key_type key = std::make_pair(root_idx, leaf);
if (leaf==PARENT_NODE_TYPE)
if (leaf==PARENT_NODE_TYPE && root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY)
{
if (detail::is_node_leaf(root_node.op))
str_ += mapping_.at(key)->evaluate(accessors_);
else if (detail::is_elementwise_operator(root_node.op))
str_ += evaluate(root_node.op.type);
else if (detail::is_elementwise_function(root_node.op) && root_node.op.type_family!=OPERATOR_UNARY_TYPE_FAMILY)
else if (detail::is_elementwise_function(root_node.op))
str_ += ",";
}
else

View File

@@ -9,6 +9,7 @@
#include "atidlas/backend/templates/base.h"
#include "atidlas/backend/parse.h"
#include "atidlas/exception/operation_not_supported.h"
#include "atidlas/exception/unknown_datatype.h"
#include "atidlas/tools/to_string.hpp"
#include "atidlas/tools/make_map.hpp"
#include "atidlas/symbolic/io.h"
@@ -46,12 +47,11 @@ tools::shared_ptr<mapped_object> base::map_functor::create(array_infos const & a
{
std::string dtype = numeric_type_to_string(a.dtype);
unsigned int id = binder_.get(a.data);
//Scalar
if(a.shape1==1 && a.shape2==1)
return tools::shared_ptr<mapped_object>(new mapped_scalar(dtype, id));
else
{
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 's'));
//Column vector
if(a.shape1>1 && a.shape2==1)
else if(a.shape1>1 && a.shape2==1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
//Row vector
else if(a.shape1==1 && a.shape2>1)
@@ -59,7 +59,6 @@ tools::shared_ptr<mapped_object> base::map_functor::create(array_infos const & a
//Matrix
else
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'm'));
}
}
tools::shared_ptr<mapped_object> base::map_functor::create(repeat_infos const &) const
@@ -131,7 +130,7 @@ void base::set_arguments_functor::set_arguments(numeric_type dtype, values_holde
case ULONG_TYPE: kernel_.setArg(current_arg_++, scal.uint64); break;
case FLOAT_TYPE: kernel_.setArg(current_arg_++, scal.float32); break;
case DOUBLE_TYPE: kernel_.setArg(current_arg_++, scal.float64); break;
default: throw "Datatype not recognized";
default: throw ;
}
}
@@ -141,16 +140,14 @@ void base::set_arguments_functor::set_arguments(array_infos const & x) const
bool is_bound = binder_.bind(x.data);
if (is_bound)
{
kernel_.setArg(current_arg_++, x.data);
//scalar
if(x.shape1==1 && x.shape2==1)
{
kernel_.setArg(current_arg_++, x.data);
kernel_.setArg(current_arg_++, cl_uint(x.start1));
}
//array
else
{
kernel_.setArg(current_arg_++, x.data);
if(x.shape1==1 || x.shape2==1)
else if(x.shape1==1 || x.shape2==1)
{
kernel_.setArg(current_arg_++, cl_uint(std::max(x.start1, x.start2)));
kernel_.setArg(current_arg_++, cl_uint(std::max(x.stride1, x.stride2)));
@@ -164,7 +161,6 @@ void base::set_arguments_functor::set_arguments(array_infos const & x) const
kernel_.setArg(current_arg_++, cl_uint(x.stride2));
}
}
}
}
void base::set_arguments_functor::set_arguments(repeat_infos const & i) const
@@ -182,7 +178,7 @@ void base::set_arguments_functor::set_arguments(lhs_rhs_element const & lhs_rhs)
case VALUE_TYPE_FAMILY: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
case ARRAY_TYPE_FAMILY: return set_arguments(lhs_rhs.array);
case INFOS_TYPE_FAMILY: return set_arguments(lhs_rhs.tuple);
default: throw "oh noez";
default: throw ;
}
}
@@ -269,7 +265,7 @@ std::string base::generate_arguments(std::vector<mapping_type> const & mappings,
std::string base::generate_arguments(std::string const & data_type, std::vector<mapping_type> const & mappings, symbolic_expressions_container const & symbolic_expressions)
{
return generate_arguments(mappings, tools::make_map<std::map<std::string, std::string> >("scalar", "__global #scalartype* #pointer,")
return generate_arguments(mappings, tools::make_map<std::map<std::string, std::string> >("array0", "__global #scalartype* #pointer, uint #start,")
("host_scalar", "#scalartype #name,")
("array1", "__global " + data_type + "* #pointer, uint #start, uint #stride,")
("array2", "__global " + data_type + "* #pointer, uint #ld, uint #start1, uint #start2, uint #stride1, uint #stride2,")

View File

@@ -34,7 +34,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
stream << "{" << std::endl;
stream.inc_tab();
process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("scalar", "#scalartype #namereg = *#pointer;")
process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array0", "#scalartype #namereg = #pointer[#start];")
("array1", "#pointer += #start;")
("array2", "#pointer = &$VALUE{#start1, #start2};"), symbolic_expressions, mappings);
@@ -58,7 +58,7 @@ std::string maxpy::generate_impl(unsigned int label, symbolic_expressions_contai
("array2", "#namereg")
("vdiag", "#namereg")
("repeat", "#namereg")
("scalar", "#namereg")
("array0", "#namereg")
("outer", "#namereg")
, symbolic_expressions, mappings);

View File

@@ -42,7 +42,7 @@ std::string mreduction::generate_impl(unsigned int label, symbolic_expressions_c
stream.inc_tab();
process(stream, PARENT_NODE_TYPE,
tools::make_map<std::map<std::string, std::string> >("scalar", "#scalartype #namereg = *#pointer;")
tools::make_map<std::map<std::string, std::string> >("array0", "#scalartype #namereg = #pointer[#start];")
("array1", "#pointer += #start;")
("array2", "#pointer += #start1 + #start2*#ld; "
"#ld *= #nldstride; "), symbolic_expressions, mappings);
@@ -105,7 +105,7 @@ std::string mreduction::generate_impl(unsigned int label, symbolic_expressions_c
std::map<std::string, std::string> accessors;
accessors["array2"] = str[a];
accessors["repeat"] = "#namereg";
accessors["scalar"] = "#namereg";
accessors["array0"] = "#namereg";
std::string value = exprs[k]->evaluate_recursive(LHS_NODE_TYPE, accessors);
if (exprs[k]->is_index_reduction())
compute_index_reduction(stream, exprs[k]->process("#name_acc"), "c*"+to_string(simd_width) + to_string(a), exprs[k]->process("#name_acc_value"), value,exprs[k]->root_op());

View File

@@ -91,7 +91,7 @@ std::string reduction::generate_impl(unsigned int label, char type, symbolic_exp
stream.inc_tab();
stream << "unsigned int lid = get_local_id(0);" << std::endl;
process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("scalar", "#scalartype #namereg = *#pointer;")
process(stream, PARENT_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array0", "#scalartype #namereg = #pointer[#start];")
("array1", "#pointer += #start;"), symbolic_expressions, mappings);
for (unsigned int k = 0; k < N; ++k)
@@ -143,7 +143,7 @@ std::string reduction::generate_impl(unsigned int label, char type, symbolic_exp
accessors["matrix_row"] = str[a];
accessors["matrix_column"] = str[a];
accessors["matrix_diag"] = str[a];
accessors["scalar"] = "#namereg";
accessors["array0"] = "#namereg";
std::string value = exprs[k]->evaluate_recursive(LHS_NODE_TYPE, accessors);
if (exprs[k]->is_index_reduction())
compute_index_reduction(stream, exprs[k]->process("#name_acc"), "i*" + tools::to_string(simd_width) + "+"
@@ -242,8 +242,7 @@ std::string reduction::generate_impl(unsigned int label, char type, symbolic_exp
stream.inc_tab();
std::map<std::string, std::string> accessors;
accessors["scalar_reduction"] = "#name_buf[0]";
accessors["scalar"] = "*#pointer";
accessors["array1"] = "#pointer[#start]";
accessors["array0"] = "#pointer[#start]";
evaluate(stream, PARENT_NODE_TYPE, accessors, symbolic_expressions, mappings);
stream.dec_tab();
stream << "}" << std::endl;

View File

@@ -39,7 +39,7 @@ std::vector<std::string> vaxpy::generate_impl(unsigned int label, symbolic_expre
stream.inc_tab();
process(stream, PARENT_NODE_TYPE,
tools::make_map<std::map<std::string, std::string> >("scalar", "#scalartype #namereg = *#pointer;")
tools::make_map<std::map<std::string, std::string> >("array0", "#scalartype #namereg = #pointer[#start];")
("array1", "#pointer += #start;")
("array1", "#start1/=" + str_simd_width + ";"), symbolic_expressions, mappings);
@@ -59,7 +59,7 @@ std::vector<std::string> vaxpy::generate_impl(unsigned int label, symbolic_expre
("matrix_row", "#namereg")
("matrix_column", "#namereg")
("matrix_diag", "#namereg")
("scalar", "#namereg"), symbolic_expressions, mappings);
("array0", "#namereg"), symbolic_expressions, mappings);
process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array1", "#pointer[i*#stride] = #namereg;")
("matrix_row", "$VALUE{#row, i} = #namereg;")
@@ -73,7 +73,7 @@ std::vector<std::string> vaxpy::generate_impl(unsigned int label, symbolic_expre
stream << "if(get_global_id(0)==0)" << std::endl;
stream << "{" << std::endl;
stream.inc_tab();
process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("scalar", "*#pointer = #namereg;"), symbolic_expressions, mappings);
process(stream, LHS_NODE_TYPE, tools::make_map<std::map<std::string, std::string> >("array0", "#pointer[#start] = #namereg;"), symbolic_expressions, mappings);
stream.dec_tab();
stream << "}" << std::endl;

View File

@@ -150,14 +150,14 @@ namespace detail
if(name=="matrix-productNT") return MATRIX_PRODUCT_NT_TYPE;
if(name=="matrix-productTN") return MATRIX_PRODUCT_TN_TYPE;
if(name=="matrix-productTT") return MATRIX_PRODUCT_TT_TYPE;
throw "Unsupported operation";
throw ;
}
static numeric_type get_dtype(std::string const & name)
{
if(name=="float32") return FLOAT_TYPE;
if(name=="float64") return DOUBLE_TYPE;
throw "Unsupported operation";
throw;
}
static tools::shared_ptr<base> create(std::string const & template_name, std::vector<int> const & a)

View File

@@ -196,7 +196,7 @@ namespace atidlas
case MATRIX_PRODUCT_TN_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape2, node.rhs.array.shape2, dtype, context)); break;
case MATRIX_PRODUCT_TT_TYPE: tmp = tools::shared_ptr<obj_base>(new array(node.lhs.array.shape2, node.rhs.array.shape1, dtype, context)); break;
default: throw "This shouldn't happen. Ever.";
default: throw ;
}
temporaries_.push_back(tmp);

View File

@@ -21,6 +21,10 @@ void fill(array const & a, array_infos& i)
i.ld = a.ld();
}
array_expression array_expression::operator-()
{ return array_expression(*this, lhs_rhs_element(), op_element(OPERATOR_UNARY_TYPE_FAMILY, OPERATOR_SUB_TYPE), shape_); }
lhs_rhs_element::lhs_rhs_element()
{
type_family = INVALID_TYPE_FAMILY;

View File

@@ -2,6 +2,7 @@
#include <iostream>
#include "atidlas/array.h"
#include "atidlas/value_scalar.h"
#include "atidlas/exception/unknown_datatype.h"
namespace atidlas
{
@@ -10,17 +11,17 @@ void value_scalar::init(scalar const & s)
{
switch(dtype_)
{
case CHAR_TYPE: values_.int8 = s;
case UCHAR_TYPE: values_.uint8 = s;
case SHORT_TYPE: values_.int16 = s;
case USHORT_TYPE: values_.uint16 = s;
case INT_TYPE: values_.int32 = s;
case UINT_TYPE: values_.uint32 = s;
case LONG_TYPE: values_.int64 = s;
case ULONG_TYPE: values_.uint64 = s;
case FLOAT_TYPE: values_.float32 = s;
case DOUBLE_TYPE: values_.float64 = s;
default: throw;
case CHAR_TYPE: values_.int8 = s; break;
case UCHAR_TYPE: values_.uint8 = s; break;
case SHORT_TYPE: values_.int16 = s; break;
case USHORT_TYPE: values_.uint16 = s; break;
case INT_TYPE: values_.int32 = s; break;
case UINT_TYPE: values_.uint32 = s; break;
case LONG_TYPE: values_.int64 = s; break;
case ULONG_TYPE: values_.uint64 = s; break;
case FLOAT_TYPE: values_.float32 = s; break;
case DOUBLE_TYPE: values_.float64 = s; break;
default: throw unknown_datatype(dtype_);
}
}
@@ -61,11 +62,12 @@ T value_scalar::cast() const
// case HALF_TYPE: return values_.float16;
case FLOAT_TYPE: return values_.float32;
case DOUBLE_TYPE: return values_.float64;
default: throw; //unreachable
default: throw unknown_datatype(dtype_); //unreachable
}
}
#define INSTANTIATE(type) value_scalar::operator type() const { return cast<type>(); }
INSTANTIATE(bool)
INSTANTIATE(cl_char)
INSTANTIATE(cl_uchar)
INSTANTIATE(cl_short)
@@ -105,7 +107,7 @@ value_scalar NAME(LDEC, RDEC)\
case ULONG_TYPE: return VALUE(cl_ulong, OP, x, y);\
case FLOAT_TYPE: return VALUE(cl_float, OP, x, y);\
case DOUBLE_TYPE: return VALUE(cl_double, OP, x, y);\
default: throw;\
default: throw unknown_datatype(x.dtype());\
}\
}
@@ -168,7 +170,7 @@ std::ostream & operator<<(std::ostream & os, value_scalar const & s)
case ULONG_TYPE: return os << static_cast<cl_ulong>(s);
case FLOAT_TYPE: return os << static_cast<cl_float>(s);
case DOUBLE_TYPE: return os << static_cast<cl_double>(s);
default: throw "";
default: throw unknown_datatype(s.dtype());;
}
}

View File

@@ -39,6 +39,7 @@ void test_element_wise_vector(T epsilon, simple_vector_base<T> & cx, simple_vect
}
RUN_TEST_VECTOR_AXPY("z = x", cz[i] = cx[i], z = x)
RUN_TEST_VECTOR_AXPY("z = -x", cz[i] = -cx[i], z = -x)
RUN_TEST_VECTOR_AXPY("z = x + y", cz[i] = cx[i] + cy[i], z = x + y)
RUN_TEST_VECTOR_AXPY("z = x - y", cz[i] = cx[i] - cy[i], z = x - y)