// File: triton/lib/kernels/templates/tools/arguments.hpp
// Snapshot of 2015-12-19 03:29:51 -05:00 (C++, 122 lines, 4.9 KiB)

#include <string>
#include <vector>
#include <memory>
#include <algorithm>
#include "isaac/kernels/mapped_object.h"
#include "isaac/kernels/parse.h"
#include "isaac/array.h"
namespace isaac
{
namespace templates
{
//Generate
/// Builds the parameter-declaration list of a generated kernel.
///
/// process() writes, for the leaves of @p expressions found in @p mappings, the
/// template registered under that leaf's type key ("array11", "arrayn",
/// "host_scalar", "tuple4", ...). The "#..." placeholders are presumably
/// substituted by process(); confirm in isaac/kernels/parse.h. Every template
/// ends with ',', so the final comma is stripped before returning.
///
/// The leading std::string parameter is unused by this function.
/// @param device      provides the global-memory qualifier (Global(device.backend()))
///                    and the type used for the #start / #stride / #ld parameters.
/// @param mappings    mapped objects for the leaves of @p expressions.
/// @param expressions expression tree whose leaves become kernel parameters.
/// @return the comma-separated declarations without a trailing comma, or an
///         empty string when nothing was emitted.
inline std::string generate_arguments(std::string const &, driver::Device const & device, mapping_type const & mappings, expression_tree const & expressions)
{
  std::string const kwglobal = Global(device.backend()).get();
  std::string const size_kw = size_type(device);

  // Every dense-array template starts with a pointer and a start offset.
  // Strided variants append one stride; "arraynn" additionally appends a leading dimension.
  std::string const ptr_start = kwglobal + " #scalartype* #pointer, " + size_kw + " #start,";
  std::string const ptr_start_stride = ptr_start + " " + size_kw + " #stride,";
  std::string const ptr_start_stride_ld = ptr_start_stride + size_kw + " #ld,";

  kernel_generation_stream stream;
  process(stream, PARENT_NODE_TYPE, { {"array11", ptr_start},
                                      {"array1", ptr_start},
                                      {"host_scalar", "#scalartype #name,"},
                                      {"arrayn", ptr_start_stride},
                                      {"array1n", ptr_start_stride},
                                      {"arrayn1", ptr_start_stride},
                                      {"arraynn", ptr_start_stride_ld},
                                      {"tuple4", "#scalartype #name0, #scalartype #name1, #scalartype #name2, #scalartype #name3,"}}
          , expressions, mappings);

  std::string res = stream.str();
  // rfind yields npos when nothing was emitted, and std::string::erase(npos)
  // throws std::out_of_range. Only strip when a comma is actually present.
  std::string::size_type const last_comma = res.rfind(',');
  if (last_comma != std::string::npos)
    res.erase(last_comma);
  return res;
}
//Enqueue
class set_arguments_functor : public traversal_functor
{
public:
typedef void result_type;
set_arguments_functor(symbolic_binder & binder, unsigned int & current_arg, driver::Kernel & kernel)
: binder_(binder), current_arg_(current_arg), kernel_(kernel)
{
}
void set_arguments(numeric_type dtype, values_holder const & scal) const
{
switch(dtype)
{
// case BOOL_TYPE: kernel_.setArg(current_arg_++, scal.bool8); break;
case CHAR_TYPE: kernel_.setArg(current_arg_++, scal.int8); break;
case UCHAR_TYPE: kernel_.setArg(current_arg_++, scal.uint8); break;
case SHORT_TYPE: kernel_.setArg(current_arg_++, scal.int16); break;
case USHORT_TYPE: kernel_.setArg(current_arg_++, scal.uint16); break;
case INT_TYPE: kernel_.setArg(current_arg_++, scal.int32); break;
case UINT_TYPE: kernel_.setArg(current_arg_++, scal.uint32); break;
case LONG_TYPE: kernel_.setArg(current_arg_++, scal.int64); break;
case ULONG_TYPE: kernel_.setArg(current_arg_++, scal.uint64); break;
// case HALF_TYPE: kernel_.setArg(current_arg_++, scal.float16); break;
case FLOAT_TYPE: kernel_.setArg(current_arg_++, scal.float32); break;
case DOUBLE_TYPE: kernel_.setArg(current_arg_++, scal.float64); break;
default: throw unknown_datatype(dtype);
}
}
void set_arguments(array_base const * a, bool is_assigned) const
{
bool is_bound = binder_.bind(a, is_assigned);
if (is_bound)
{
kernel_.setArg(current_arg_++, a->data());
kernel_.setSizeArg(current_arg_++, a->start());
for(int_t i = 0 ; i < a->dim() ; i++){
if(a->shape()[i] > 1)
kernel_.setSizeArg(current_arg_++, a->stride()[i]);
}
}
}
void set_arguments(tree_node const & lhs_rhs, bool is_assigned) const
{
switch(lhs_rhs.subtype)
{
case VALUE_SCALAR_TYPE: return set_arguments(lhs_rhs.dtype, lhs_rhs.vscalar);
case DENSE_ARRAY_TYPE: return set_arguments(lhs_rhs.array, is_assigned);
case FOR_LOOP_INDEX_TYPE: return;
default: throw std::runtime_error("Unrecognized type family");
}
}
void operator()(isaac::expression_tree const & expression_tree, size_t root_idx, leaf_t leaf_t) const
{
expression_tree::node const & root_node = expression_tree.tree()[root_idx];
if (leaf_t==LHS_NODE_TYPE && root_node.lhs.subtype != COMPOSITE_OPERATOR_TYPE)
set_arguments(root_node.lhs, detail::is_assignment(root_node.op));
else if (leaf_t==RHS_NODE_TYPE && root_node.rhs.subtype != COMPOSITE_OPERATOR_TYPE)
set_arguments(root_node.rhs, false);
}
private:
symbolic_binder & binder_;
unsigned int & current_arg_;
driver::Kernel & kernel_;
};
/// Sets every runtime argument of @p kernel required by @p expression.
///
/// Arguments start at index @p current_arg, which is advanced in place once per argument set.
/// @param binding_policy BIND_SEQUENTIAL selects bind_sequential; any other value
///                       selects bind_independent. The chosen binder decides which
///                       arrays receive arguments.
inline void set_arguments(expression_tree const & expression, driver::Kernel & kernel, unsigned int & current_arg, binding_policy_t binding_policy)
{
  // The binder is owned here and only lives for the duration of the traversal.
  std::unique_ptr<symbolic_binder> const binder(
      binding_policy == BIND_SEQUENTIAL
          ? static_cast<symbolic_binder *>(new bind_sequential())
          : static_cast<symbolic_binder *>(new bind_independent()));
  traverse(expression, expression.root(), set_arguments_functor(*binder, current_arg, kernel), true);
}
}
}