Files
triton/lib/jit/syntax/engine/process.cpp
2016-04-10 13:13:16 -04:00

234 lines
7.8 KiB
C++

/*
* Copyright (c) 2015, PHILIPPE TILLET. All rights reserved.
*
* This file is part of ISAAC.
*
* ISAAC is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
* MA 02110-1301 USA
*/
#include "isaac/jit/syntax/engine/process.h"
namespace isaac
{
namespace symbolic
{
// Filter nodes
/**
 * Returns the indices of every node visited by traverse(tree, root, ...) for
 * which `pred` holds, in visiting order.
 */
std::vector<size_t> find(expression_tree const & tree, size_t root, std::function<bool (expression_tree::node const &)> const & pred)
{
  std::vector<size_t> matches;
  auto visit = [&](size_t idx)
  {
    expression_tree::node const & candidate = tree[idx];
    if(!pred(candidate))
      return;
    matches.push_back(idx);
  };
  traverse(tree, root, visit);
  return matches;
}
/**
 * Convenience overload: searches the whole tree, starting from tree.root().
 */
std::vector<size_t> find(expression_tree const & tree, std::function<bool (expression_tree::node const &)> const & pred)
{
  size_t const start = tree.root();
  return find(tree, start, pred);
}
/**
 * Returns the indices of all composite-operator nodes whose operator type is
 * an assignment (per is_assignment), searching the whole tree.
 */
std::vector<size_t> assignments(expression_tree const & tree)
{
  auto is_assignment_node = [](expression_tree::node const & node) -> bool
  {
    if(node.type != COMPOSITE_OPERATOR_TYPE)
      return false;
    return is_assignment(node.binary_operator.op.type);
  };
  return find(tree, is_assignment_node);
}
/**
 * Maps each binary-operator node index in `in` to the index of its left
 * operand. The result has the same length and order as `in`.
 */
std::vector<size_t> lhs_of(expression_tree const & tree, std::vector<size_t> const & in)
{
  std::vector<size_t> operands;
  operands.reserve(in.size());
  for(size_t const node_idx : in)
    operands.push_back(tree[node_idx].binary_operator.lhs);
  return operands;
}
/**
 * Maps each binary-operator node index in `in` to the index of its right
 * operand. The result has the same length and order as `in`.
 */
std::vector<size_t> rhs_of(expression_tree const & tree, std::vector<size_t> const & in)
{
  std::vector<size_t> operands;
  operands.reserve(in.size());
  for(size_t const node_idx : in)
    operands.push_back(tree[node_idx].binary_operator.rhs);
  return operands;
}
// Hash
/**
 * Builds a structural signature string for `tree`.
 *
 * Every node visited by traverse(tree, ...) contributes:
 *  - DENSE_ARRAY_TYPE: one char per dimension ('n' if the extent is > 1,
 *    '1' otherwise), an 's' if ld[0] > 1, the dtype value as a raw char, and
 *    the id returned by a bind_independent binder for the array handle;
 *  - COMPOSITE_OPERATOR_TYPE: the operator type, rendered by tools::fast_append.
 * Other node types contribute nothing.
 *
 * The signature is accumulated in a std::string. Previously it was written
 * into a fixed 256-byte stack buffer with no bounds check, which overflowed
 * for large trees.
 */
std::string hash(expression_tree const & tree)
{
  driver::backend_type backend = tree.context().backend();
  bind_independent binder(backend);
  std::string signature;
  // Scratch space for a single tools::fast_append call. Each call below
  // appends one integral value (a binder id or an operator type).
  // NOTE(review): assumes fast_append renders such a value in well under
  // 64 chars (it writes through a raw char pointer) -- confirm.
  char scratch[64];
  auto hash_impl = [&](size_t idx)
  {
    expression_tree::node const & node = tree.data()[idx];
    if(node.type==DENSE_ARRAY_TYPE)
    {
      for(size_t i = 0 ; i < node.shape.size() ; ++i)
        signature += (node.shape[i]>1 ? 'n' : '1');
      if(node.ld[0]>1)
        signature += 's';
      signature += static_cast<char>(node.dtype);
      char* end = scratch;
      tools::fast_append(end, binder.get(node.array.handle, false));
      signature.append(scratch, end);
    }
    else if(node.type==COMPOSITE_OPERATOR_TYPE)
    {
      char* end = scratch;
      tools::fast_append(end, node.binary_operator.op.type);
      signature.append(scratch, end);
    }
  };
  traverse(tree, hash_impl);
  return signature;
}
//Set arguments
/**
 * Sets the runtime arguments of `tree` on `kernel`, starting at argument index
 * `current_arg`. `current_arg` is advanced past every argument that is set.
 *
 * Nodes are visited in traverse(tree, ...) order:
 *  - VALUE_SCALAR_TYPE: the scalar value (with its dtype);
 *  - DENSE_ARRAY_TYPE: only when binder->bind() returns true for the handle,
 *    sets the backend handle (OpenCL or CUDA), the start offset, and ld[i] for
 *    every dimension i whose extent is greater than 1;
 *  - RESHAPE_TYPE operators: the running products of the new shape's extents,
 *    then of the old (lhs operand's) shape, for each dimension i >= 1 whose
 *    extent is greater than 1.
 * Other node types contribute no arguments.
 *
 * @param fusion_policy FUSE_SEQUENTIAL selects bind_sequential; any other value
 *        selects bind_independent. NOTE(review): presumably must match the
 *        policy used when the kernel source was generated -- confirm.
 */
void set_arguments(expression_tree const & tree, driver::Kernel & kernel, unsigned int & current_arg, fusion_policy_t fusion_policy)
{
  driver::backend_type backend = tree.context().backend();
  //Create binder
  std::unique_ptr<symbolic_binder> binder;
  if (fusion_policy==FUSE_SEQUENTIAL)
    binder.reset(new bind_sequential(backend));
  else
    binder.reset(new bind_independent(backend));
  //Arrays written by an assignment: the lhs operand of every assignment node
  std::vector<size_t> const assignee = lhs_of(tree, assignments(tree));
  //For each dimension i >= 1 with extent > 1, emits prod(shape[0..i-1])
  auto set_stride_args = [&](tuple const & shape)
  {
    int_t stride = 1;
    for(size_t i = 1 ; i < shape.size() ; ++i)
    {
      stride *= shape[i-1];
      if(shape[i] > 1)
        kernel.setSizeArg(current_arg++, stride);
    }
  };
  //set_arguments_impl
  auto set_arguments_impl = [&](size_t index)
  {
    expression_tree::node const & node = tree.data()[index];
    if(node.type==VALUE_SCALAR_TYPE)
      kernel.setArg(current_arg++,value_scalar(node.scalar,node.dtype));
    else if(node.type==DENSE_ARRAY_TYPE)
    {
      array_holder const & array = node.array;
      bool is_assigned = std::find(assignee.begin(), assignee.end(), index)!=assignee.end();
      bool is_bound = binder->bind(array.handle, is_assigned);
      //Only arrays the binder reports as newly bound get their own arguments
      if (is_bound)
      {
        if(backend==driver::OPENCL)
          kernel.setArg(current_arg++, array.handle.cl);
        else
          kernel.setArg(current_arg++, array.handle.cu);
        kernel.setSizeArg(current_arg++, array.start);
        for(size_t i = 0 ; i < node.shape.size() ; i++)
        {
          if(node.shape[i] > 1)
            kernel.setSizeArg(current_arg++, node.ld[i]);
        }
      }
    }
    else if(node.type==COMPOSITE_OPERATOR_TYPE && node.binary_operator.op.type == RESHAPE_TYPE)
    {
      //New shape first, then the shape of the reshaped (lhs) operand
      set_stride_args(node.shape);
      set_stride_args(tree.data()[node.binary_operator.lhs].shape);
    }
  };
  //Traverse
  traverse(tree, set_arguments_impl);
}
//Symbolize
/**
 * Constructs a T from `args` (perfectly forwarded) and returns it as a
 * std::shared_ptr<object>. Uses std::make_shared, so the object and its
 * reference count share a single allocation and no raw `new` is needed.
 */
template<class T, class... Args>
std::shared_ptr<object> make_symbolic(Args&&... args)
{
  return std::make_shared<T>(std::forward<Args>(args)...);
}
/**
 * Builds the symbols table for `tree`: one symbolic object per handled node
 * visited by traverse(tree, ...), keyed by node index.
 *
 * Ids come from a binder chosen by `fusion_policy` (bind_sequential for
 * FUSE_SEQUENTIAL, bind_independent otherwise):
 *  - VALUE_SCALAR_TYPE       -> host_scalar with an id from binder->get();
 *  - DENSE_ARRAY_TYPE        -> buffer with an id obtained from the array
 *                               handle, flagged as assigned when the array is
 *                               the lhs operand of an assignment;
 *  - PLACEHOLDER_TYPE        -> placeholder at the node's level;
 *  - COMPOSITE_OPERATOR_TYPE -> an id is always drawn from binder->get(); then
 *                               a reshape / trans / diag_vector / unary / binary
 *                               / reduce_1d / reduce_2d symbol is created by
 *                               operator type or family. Other operators get
 *                               no entry.
 *
 * Composite symbols are constructed with a reference to `table`.
 * NOTE(review): this presumably relies on traverse() visiting operands before
 * their parent, so the operand symbols already exist -- confirm.
 */
symbols_table symbolize(fusion_policy_t fusion_policy, isaac::expression_tree const & tree)
{
  driver::Context const & context = tree.context();
  //binder
  symbols_table table;
  std::unique_ptr<symbolic_binder> binder;
  if (fusion_policy==FUSE_SEQUENTIAL)
    binder.reset(new bind_sequential(context.backend()));
  else
    binder.reset(new bind_independent(context.backend()));
  //Arrays written by an assignment: the lhs operand of every assignment node
  std::vector<size_t> const assignee = lhs_of(tree, assignments(tree));
  //symbolize_impl
  auto symbolize_impl = [&](size_t root)
  {
    expression_tree::node const & node = tree.data()[root];
    std::string dtype = to_string(node.dtype);
    if(node.type==VALUE_SCALAR_TYPE)
      table.insert({root, make_symbolic<host_scalar>(context, dtype, binder->get())});
    else if(node.type==DENSE_ARRAY_TYPE){
      bool is_assigned = std::find(assignee.begin(), assignee.end(), root)!=assignee.end();
      table.insert({root, make_symbolic<buffer>(context, dtype, binder->get(node.array.handle, is_assigned), node.shape, node.ld)});
    }
    else if(node.type==PLACEHOLDER_TYPE)
      table.insert({root, make_symbolic<placeholder>(context, node.ph.level)});
    else if(node.type==COMPOSITE_OPERATOR_TYPE)
    {
      //Drawn for every composite node, even ones that end up with no symbol
      unsigned int id = binder->get();
      op_element op = node.binary_operator.op;
      //Index modifier
      if(op.type==RESHAPE_TYPE)
        table.insert({root, make_symbolic<reshape>(dtype, id, root, op, tree, table)});
      else if(op.type==TRANS_TYPE)
        table.insert({root, make_symbolic<trans>(dtype, id, root, op, tree, table)});
      else if(op.type==DIAG_VECTOR_TYPE)
        table.insert({root, make_symbolic<diag_vector>(dtype, id, root, op, tree, table)});
      //Unary arithmetic
      else if(op.type_family==UNARY_ARITHMETIC)
        table.insert({root, make_symbolic<unary_arithmetic_node>(id, root, op, tree, table)});
      //Binary arithmetic
      else if(op.type_family==BINARY_ARITHMETIC)
        table.insert({root, make_symbolic<binary_arithmetic_node>(id, root, op, tree, table)});
      //1D Reduction
      else if (op.type_family==REDUCE)
        table.insert({root, make_symbolic<reduce_1d>(id, root, op, tree, table)});
      //2D reduction
      else if (op.type_family==REDUCE_ROWS || op.type_family==REDUCE_COLUMNS)
        table.insert({root, make_symbolic<reduce_2d>(id, root, op, tree, table)});
    }
  };
  //traverse
  traverse(tree, symbolize_impl);
  return table;
}
}
}