More flexibility in scalars
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include "atidlas/backend/templates/base.h"
|
||||
#include "atidlas/backend/parse.h"
|
||||
#include "atidlas/exception/operation_not_supported.h"
|
||||
#include "atidlas/exception/unknown_datatype.h"
|
||||
#include "atidlas/tools/to_string.hpp"
|
||||
#include "atidlas/tools/make_map.hpp"
|
||||
#include "atidlas/symbolic/io.h"
|
||||
@@ -46,20 +47,18 @@ tools::shared_ptr<mapped_object> base::map_functor::create(array_infos const & a
|
||||
{
|
||||
std::string dtype = numeric_type_to_string(a.dtype);
|
||||
unsigned int id = binder_.get(a.data);
|
||||
//Scalar
|
||||
if(a.shape1==1 && a.shape2==1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_scalar(dtype, id));
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 's'));
|
||||
//Column vector
|
||||
else if(a.shape1>1 && a.shape2==1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
|
||||
//Row vector
|
||||
else if(a.shape1==1 && a.shape2>1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
|
||||
//Matrix
|
||||
else
|
||||
{
|
||||
//Column vector
|
||||
if(a.shape1>1 && a.shape2==1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
|
||||
//Row vector
|
||||
else if(a.shape1==1 && a.shape2>1)
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
|
||||
//Matrix
|
||||
else
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'm'));
|
||||
}
|
||||
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'm'));
|
||||
}
|
||||
|
||||
tools::shared_ptr<mapped_object> base::map_functor::create(repeat_infos const &) const
|
||||
@@ -131,7 +130,7 @@ void base::set_arguments_functor::set_arguments(numeric_type dtype, values_holde
|
||||
case ULONG_TYPE: kernel_.setArg(current_arg_++, scal.uint64); break;
|
||||
case FLOAT_TYPE: kernel_.setArg(current_arg_++, scal.float32); break;
|
||||
case DOUBLE_TYPE: kernel_.setArg(current_arg_++, scal.float64); break;
|
||||
default: throw "Datatype not recognized";
|
||||
default: throw ;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,28 +140,25 @@ void base::set_arguments_functor::set_arguments(array_infos const & x) const
|
||||
bool is_bound = binder_.bind(x.data);
|
||||
if (is_bound)
|
||||
{
|
||||
kernel_.setArg(current_arg_++, x.data);
|
||||
//scalar
|
||||
if(x.shape1==1 && x.shape2==1)
|
||||
{
|
||||
kernel_.setArg(current_arg_++, x.data);
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start1));
|
||||
}
|
||||
//array
|
||||
else if(x.shape1==1 || x.shape2==1)
|
||||
{
|
||||
kernel_.setArg(current_arg_++, cl_uint(std::max(x.start1, x.start2)));
|
||||
kernel_.setArg(current_arg_++, cl_uint(std::max(x.stride1, x.stride2)));
|
||||
}
|
||||
else
|
||||
{
|
||||
kernel_.setArg(current_arg_++, x.data);
|
||||
if(x.shape1==1 || x.shape2==1)
|
||||
{
|
||||
kernel_.setArg(current_arg_++, cl_uint(std::max(x.start1, x.start2)));
|
||||
kernel_.setArg(current_arg_++, cl_uint(std::max(x.stride1, x.stride2)));
|
||||
}
|
||||
else
|
||||
{
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.ld));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start2));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride2));
|
||||
}
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.ld));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start2));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride2));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -182,7 +178,7 @@ void base::set_arguments_functor::set_arguments(lhs_rhs_element const & lhs_rhs)
|
||||
case VALUE_TYPE_FAMILY: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
|
||||
case ARRAY_TYPE_FAMILY: return set_arguments(lhs_rhs.array);
|
||||
case INFOS_TYPE_FAMILY: return set_arguments(lhs_rhs.tuple);
|
||||
default: throw "oh noez";
|
||||
default: throw ;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,7 +265,7 @@ std::string base::generate_arguments(std::vector<mapping_type> const & mappings,
|
||||
|
||||
std::string base::generate_arguments(std::string const & data_type, std::vector<mapping_type> const & mappings, symbolic_expressions_container const & symbolic_expressions)
|
||||
{
|
||||
return generate_arguments(mappings, tools::make_map<std::map<std::string, std::string> >("scalar", "__global #scalartype* #pointer,")
|
||||
return generate_arguments(mappings, tools::make_map<std::map<std::string, std::string> >("array0", "__global #scalartype* #pointer, uint #start,")
|
||||
("host_scalar", "#scalartype #name,")
|
||||
("array1", "__global " + data_type + "* #pointer, uint #start, uint #stride,")
|
||||
("array2", "__global " + data_type + "* #pointer, uint #ld, uint #start1, uint #start2, uint #stride1, uint #stride2,")
|
||||
|
Reference in New Issue
Block a user