various changes

This commit is contained in:
Philippe Tillet
2015-01-17 15:47:52 -05:00
parent 0068560bc6
commit 16648f18e0
14 changed files with 134 additions and 121 deletions

View File

@@ -107,8 +107,6 @@ void base::map_functor::operator()(atidlas::symbolic_expression const & symbolic
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mreduction>(&symbolic_expression, root_idx, &mapping_)));
else if (root_node.op.type_family == OPERATOR_MATRIX_PRODUCT_TYPE_FAMILY)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_mproduct>(&symbolic_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_TRANS_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_trans>(&symbolic_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_REPEAT_TYPE)
mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_repeat>(&symbolic_expression, root_idx, &mapping_)));
else if (root_node.op.type == OPERATOR_OUTER_PROD_TYPE)
@@ -152,11 +150,19 @@ void base::set_arguments_functor::set_arguments(array const & x) const
else
{
kernel_.setArg(current_arg_++, x.data());
kernel_.setArg(current_arg_++, cl_uint(x.ld()));
kernel_.setArg(current_arg_++, cl_uint(x.start()._1));
kernel_.setArg(current_arg_++, cl_uint(x.start()._2));
kernel_.setArg(current_arg_++, cl_uint(x.stride()._1));
kernel_.setArg(current_arg_++, cl_uint(x.stride()._2));
if(x.nshape()==1)
{
kernel_.setArg(current_arg_++, cl_uint(max(x.start())));
kernel_.setArg(current_arg_++, cl_uint(max(x.stride())));
}
else
{
kernel_.setArg(current_arg_++, cl_uint(x.ld()));
kernel_.setArg(current_arg_++, cl_uint(x.start()._1));
kernel_.setArg(current_arg_++, cl_uint(x.start()._2));
kernel_.setArg(current_arg_++, cl_uint(x.stride()._1));
kernel_.setArg(current_arg_++, cl_uint(x.stride()._2));
}
}
}
}
@@ -265,7 +271,8 @@ std::string base::generate_arguments(std::string const & data_type, std::vector<
{
return generate_arguments(mappings, tools::make_map<std::map<std::string, std::string> >("scalar", "__global #scalartype* #pointer,")
("host_scalar", "#scalartype #name,")
("array", "__global " + data_type + "* #pointer, uint #ld, uint #start1, uint #start2, uint #stride1, uint #stride2,")
("array1", "__global " + data_type + "* #pointer, uint #start, uint #stride,")
("array2", "__global " + data_type + "* #pointer, uint #ld, uint #start1, uint #start2, uint #stride1, uint #stride2,")
("tuple4", "#scalartype #name0, #scalartype #name1, #scalartype #name2, #scalartype #name3,"), symbolic_expressions);
}