More flexibility in scalars

This commit is contained in:
Philippe Tillet
2015-01-19 21:29:47 -05:00
parent 8694bacaab
commit 4f73fb384f
18 changed files with 127 additions and 113 deletions

View File

@@ -9,6 +9,7 @@
#include "atidlas/backend/templates/base.h"
#include "atidlas/backend/parse.h"
#include "atidlas/exception/operation_not_supported.h"
#include "atidlas/exception/unknown_datatype.h"
#include "atidlas/tools/to_string.hpp"
#include "atidlas/tools/make_map.hpp"
#include "atidlas/symbolic/io.h"
@@ -46,20 +47,18 @@ tools::shared_ptr<mapped_object> base::map_functor::create(array_infos const & a
{
std::string dtype = numeric_type_to_string(a.dtype);
unsigned int id = binder_.get(a.data);
//Scalar
if(a.shape1==1 && a.shape2==1)
return tools::shared_ptr<mapped_object>(new mapped_scalar(dtype, id));
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 's'));
//Column vector
else if(a.shape1>1 && a.shape2==1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
//Row vector
else if(a.shape1==1 && a.shape2>1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
//Matrix
else
{
//Column vector
if(a.shape1>1 && a.shape2==1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'c'));
//Row vector
else if(a.shape1==1 && a.shape2>1)
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'r'));
//Matrix
else
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'm'));
}
return tools::shared_ptr<mapped_object>(new mapped_array(dtype, id, 'm'));
}
tools::shared_ptr<mapped_object> base::map_functor::create(repeat_infos const &) const
@@ -131,7 +130,7 @@ void base::set_arguments_functor::set_arguments(numeric_type dtype, values_holde
case ULONG_TYPE: kernel_.setArg(current_arg_++, scal.uint64); break;
case FLOAT_TYPE: kernel_.setArg(current_arg_++, scal.float32); break;
case DOUBLE_TYPE: kernel_.setArg(current_arg_++, scal.float64); break;
default: throw "Datatype not recognized";
default: throw ;
}
}
@@ -141,28 +140,25 @@ void base::set_arguments_functor::set_arguments(array_infos const & x) const
bool is_bound = binder_.bind(x.data);
if (is_bound)
{
kernel_.setArg(current_arg_++, x.data);
//scalar
if(x.shape1==1 && x.shape2==1)
{
kernel_.setArg(current_arg_++, x.data);
kernel_.setArg(current_arg_++, cl_uint(x.start1));
}
//array
else if(x.shape1==1 || x.shape2==1)
{
kernel_.setArg(current_arg_++, cl_uint(std::max(x.start1, x.start2)));
kernel_.setArg(current_arg_++, cl_uint(std::max(x.stride1, x.stride2)));
}
else
{
kernel_.setArg(current_arg_++, x.data);
if(x.shape1==1 || x.shape2==1)
{
kernel_.setArg(current_arg_++, cl_uint(std::max(x.start1, x.start2)));
kernel_.setArg(current_arg_++, cl_uint(std::max(x.stride1, x.stride2)));
}
else
{
kernel_.setArg(current_arg_++, cl_uint(x.ld));
kernel_.setArg(current_arg_++, cl_uint(x.start1));
kernel_.setArg(current_arg_++, cl_uint(x.start2));
kernel_.setArg(current_arg_++, cl_uint(x.stride1));
kernel_.setArg(current_arg_++, cl_uint(x.stride2));
}
kernel_.setArg(current_arg_++, cl_uint(x.ld));
kernel_.setArg(current_arg_++, cl_uint(x.start1));
kernel_.setArg(current_arg_++, cl_uint(x.start2));
kernel_.setArg(current_arg_++, cl_uint(x.stride1));
kernel_.setArg(current_arg_++, cl_uint(x.stride2));
}
}
}
@@ -182,7 +178,7 @@ void base::set_arguments_functor::set_arguments(lhs_rhs_element const & lhs_rhs)
case VALUE_TYPE_FAMILY: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
case ARRAY_TYPE_FAMILY: return set_arguments(lhs_rhs.array);
case INFOS_TYPE_FAMILY: return set_arguments(lhs_rhs.tuple);
default: throw "oh noez";
default: throw ;
}
}
@@ -269,7 +265,7 @@ std::string base::generate_arguments(std::vector<mapping_type> const & mappings,
std::string base::generate_arguments(std::string const & data_type, std::vector<mapping_type> const & mappings, symbolic_expressions_container const & symbolic_expressions)
{
return generate_arguments(mappings, tools::make_map<std::map<std::string, std::string> >("scalar", "__global #scalartype* #pointer,")
return generate_arguments(mappings, tools::make_map<std::map<std::string, std::string> >("array0", "__global #scalartype* #pointer, uint #start,")
("host_scalar", "#scalartype #name,")
("array1", "__global " + data_type + "* #pointer, uint #start, uint #stride,")
("array2", "__global " + data_type + "* #pointer, uint #ld, uint #start1, uint #start2, uint #stride1, uint #stride2,")