various changes
This commit is contained in:
@@ -107,8 +107,6 @@ void base::map_functor::operator()(atidlas::symbolic_expression const & symbolic
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mreduction>(&symbolic_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type_family == OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mproduct>(&symbolic_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type == OPERATOR_TRANS_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_trans>(&symbolic_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
|
||||
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&symbolic_expression, root_idx, &mapping_)));
|
||||
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
|
||||
@@ -152,11 +150,19 @@ void base::set_arguments_functor::set_arguments(array const & x) const
|
||||
else
|
||||
{
|
||||
kernel_.setArg(current_arg_++, x.data());
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.ld()));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start()._1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start()._2));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride()._1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride()._2));
|
||||
if(x.nshape()==1)
|
||||
{
|
||||
kernel_.setArg(current_arg_++, cl_uint(max(x.start())));
|
||||
kernel_.setArg(current_arg_++, cl_uint(max(x.stride())));
|
||||
}
|
||||
else
|
||||
{
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.ld()));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start()._1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.start()._2));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride()._1));
|
||||
kernel_.setArg(current_arg_++, cl_uint(x.stride()._2));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -265,7 +271,8 @@ std::string base::generate_arguments(std::string const & data_type, std::vector<
|
||||
{
|
||||
return generate_arguments(mappings, tools::make_map<std::map<std::string, std::string> >("scalar", "__global #scalartype* #pointer,")
|
||||
("host_scalar", "#scalartype #name,")
|
||||
("array", "__global " + data_type + "* #pointer, uint #ld, uint #start1, uint #start2, uint #stride1, uint #stride2,")
|
||||
("array1", "__global " + data_type + "* #pointer, uint #start, uint #stride,")
|
||||
("array2", "__global " + data_type + "* #pointer, uint #ld, uint #start1, uint #start2, uint #stride1, uint #stride2,")
|
||||
("tuple4", "#scalartype #name0, #scalartype #name1, #scalartype #name2, #scalartype #name3,"), symbolic_expressions);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user