// File: triton/lib/jit/syntax/expression/expression.cpp
// Last modified: 2017-01-17 20:33:46 -05:00 (201 lines, 6.8 KiB, C++)

/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <algorithm>
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>
#include "isaac/array.h"
#include "isaac/value_scalar.h"
#include "isaac/exception/api.h"
#include "isaac/jit/syntax/expression/expression.h"
#include "isaac/jit/syntax/expression/preset.h"
#include "isaac/tools/cpp/string.hpp"
namespace isaac
{
// ---------------------------------------------------------------------------
// expression_tree::node constructors
// ---------------------------------------------------------------------------

/** @brief Default constructor; members are left default-initialized. */
expression_tree::node::node()
{
}

/** @brief Placeholder node with no value (used, e.g., as the empty operand slot
 *  of the unary operators below). Only `type` and `dtype` are set. */
expression_tree::node::node(invalid_node)
  : type(INVALID_SUBTYPE),
    dtype(INVALID_NUMERIC_TYPE)
{
}

/** @brief Leaf node holding a copy of a host scalar's value; the shape is
 *  brace-initialized to {1}. */
expression_tree::node::node(value_scalar const & x)
  : type(VALUE_SCALAR_TYPE),
    dtype(x.dtype()),
    shape{1},
    scalar(x.values())
{
}
/** @brief Leaf node referring to a dense array.
 *
 *  Copies the array's dtype, shape, start offset and stride into the node,
 *  stores a back-pointer to @p x, and copies the raw backend buffer handle.
 *
 *  @param x array to reference. Only its address is stored (array.base), so
 *           @p x must outlive any use of that pointer.
 */
expression_tree::node::node(array_base const & x) : type(DENSE_ARRAY_TYPE), dtype(x.dtype()), shape(x.shape())
{
  array.start = x.start();
  // The node keeps a mutable pointer even though x arrives as const&; strip
  // constness explicitly instead of hiding it inside a C-style cast.
  array.base = const_cast<array_base*>(&x);
  driver::Buffer::handle_type const & h = x.data().handle();
  // Copy the raw handle of whichever backend owns the buffer.
  // NOTE(review): a backend other than OPENCL/CUDA leaves array.handle unset.
  switch(h.backend())
  {
    case driver::OPENCL: array.handle.cl = h.cl(); break;
    case driver::CUDA:   array.handle.cu = h.cu(); break;
  }
  ld = x.stride();
}
/** @brief Interior (operator) node.
 *
 *  @param lhs index, in the owning tree's node array, of the left operand
 *  @param op  operator applied to the two operands
 *  @param rhs index of the right operand
 *  @param dt  numeric type of the result
 *  @param sh  shape of the result
 */
expression_tree::node::node(int_t lhs, op_element op, int_t rhs, numeric_type dt, tuple const & sh)
  : type(COMPOSITE_OPERATOR_TYPE),
    dtype(dt),
    shape(sh)
{
  binary_operator.lhs = lhs;
  binary_operator.op = op;
  binary_operator.rhs = rhs;
}
// ---------------------------------------------------------------------------
// expression_tree constructors: each builds `lhs op rhs` in a flat node array
// whose last element is the root operator node.
// ---------------------------------------------------------------------------

/** @brief Builds the three-node tree `lhs op rhs` from two single nodes.
 *
 *  Layout: [0] = lhs, [1] = rhs, [2] = root operator node referring to 0 and 1.
 */
expression_tree::expression_tree(node const & lhs, node const & rhs, op_element const & op, driver::Context const * context, numeric_type const & dtype, tuple const & shape) :
  root_(2), context_(context)
{
  tree_.reserve(3);
  // lhs/rhs are const references, so they are copied; the former std::move on
  // them was a no-op that only suggested a move.
  tree_.push_back(lhs);
  tree_.push_back(rhs);
  tree_.push_back(node(0, op, 1, dtype, shape));
}
/** @brief Builds `lhs op rhs` from an existing tree (lhs) and a single node (rhs).
 *
 *  With n = lhs.tree_.size(): lhs's nodes are copied unchanged into [0, n), so
 *  their child indices stay valid; rhs goes to [n]; the root goes to [n+1].
 */
expression_tree::expression_tree(expression_tree const & lhs, node const & rhs, op_element const & op, driver::Context const * context, numeric_type const & dtype, tuple const & shape) :
  tree_(lhs.tree_.size() + 2),
  // Derived from lhs rather than tree_.size() so the value does not depend on
  // tree_ being declared (hence initialized) before root_ in the class.
  root_(lhs.tree_.size() + 1),
  context_(context)
{
  // lhs is const: a std::move over its range would copy anyway.
  std::copy(lhs.tree_.begin(), lhs.tree_.end(), tree_.begin());
  tree_[root_ - 1] = rhs;
  tree_[root_] = node(lhs.root_, op, root_ - 1, dtype, shape);
}
/** @brief Builds `lhs op rhs` from a single node (lhs) and an existing tree (rhs).
 *
 *  With n = rhs.tree_.size(): rhs's nodes are copied unchanged into [0, n), so
 *  their child indices stay valid; lhs goes to [n] (after rhs's subtree); the
 *  root goes to [n+1].
 */
expression_tree::expression_tree(node const & lhs, expression_tree const & rhs, op_element const & op, driver::Context const * context, numeric_type const & dtype, tuple const & shape) :
  tree_(rhs.tree_.size() + 2),
  // Derived from rhs rather than tree_.size() so the value does not depend on
  // tree_ being declared (hence initialized) before root_ in the class.
  root_(rhs.tree_.size() + 1),
  context_(context)
{
  // rhs is const: a std::move over its range would copy anyway.
  std::copy(rhs.tree_.begin(), rhs.tree_.end(), tree_.begin());
  tree_[root_ - 1] = lhs;
  tree_[root_] = node(root_ - 1, op, rhs.root_, dtype, shape);
}
/** @brief Builds `lhs op rhs` from two existing trees.
 *
 *  Layout: lhs's nodes occupy [0, lsize), rhs's nodes occupy
 *  [lsize, lsize + rsize), and the new root is the last element. Because rhs's
 *  nodes are shifted right by lsize, the child indices of its operator nodes
 *  are rebased by lsize.
 */
expression_tree::expression_tree(expression_tree const & lhs, expression_tree const & rhs, op_element const & op, driver::Context const * context, numeric_type const & dtype, tuple const & shape):
  tree_(lhs.tree_.size() + rhs.tree_.size() + 1),
  // Derived from the operands rather than tree_.size() so the value does not
  // depend on tree_ being declared (hence initialized) before root_.
  root_(lhs.tree_.size() + rhs.tree_.size()),
  context_(context)
{
  std::size_t const lsize = lhs.tree_.size();
  // Both operands are const: copy explicitly (std::move here would copy anyway).
  std::copy(lhs.tree_.begin(), lhs.tree_.end(), tree_.begin());
  std::copy(rhs.tree_.begin(), rhs.tree_.end(), tree_.begin() + lsize);
  tree_[root_] = node(lhs.root_, op, lsize + rhs.root_, dtype, shape);
  // Rebase child indices of every node that came from rhs (all nodes after
  // lhs's block except the freshly built root, which is already correct).
  for(data_type::iterator it = tree_.begin() + lsize ; it != tree_.end() - 1 ; ++it)
  {
    if(it->type == COMPOSITE_OPERATOR_TYPE)
    {
      it->binary_operator.lhs += lsize;
      it->binary_operator.rhs += lsize;
    }
  }
}
// ---------------------------------------------------------------------------
// Read-only accessors. The root node tree_[root_] carries the dtype and shape
// of the whole expression.
// ---------------------------------------------------------------------------

/** @brief Flat array of all nodes of the expression. */
expression_tree::data_type const & expression_tree::data() const
{
  return tree_;
}

/** @brief Index of the root node within data(). */
std::size_t expression_tree::root() const
{
  return root_;
}

/** @brief Context passed at construction (dereferenced; must not be null). */
driver::Context const & expression_tree::context() const
{
  return *context_;
}

/** @brief Numeric type of the expression's result (the root node's dtype). */
numeric_type const & expression_tree::dtype() const
{
  node const & top = tree_[root_];
  return top.dtype;
}

/** @brief Shape of the expression's result (a copy of the root node's shape). */
tuple expression_tree::shape() const
{
  node const & top = tree_[root_];
  return top.shape;
}

/** @brief Number of entries in shape(). */
int_t expression_tree::dim() const
{
  return static_cast<int_t>(shape().size());
}
/** @brief Unary minus: wraps this tree as the lhs of a UNARY_ARITHMETIC /
 *  SUB_TYPE node whose rhs slot is an invalid placeholder node. The result
 *  keeps this expression's dtype and shape. */
expression_tree expression_tree::operator-()
{
  op_element const negation(UNARY_ARITHMETIC, SUB_TYPE);
  return expression_tree(*this, node(invalid_node()), negation, context_, dtype(), shape());
}

/** @brief operator!: wraps this tree as the lhs of a UNARY_ARITHMETIC /
 *  NEGATE_TYPE node (rhs slot is an invalid placeholder node). The result
 *  dtype is forced to INT_TYPE; the shape is unchanged. */
expression_tree expression_tree::operator!()
{
  op_element const negate(UNARY_ARITHMETIC, NEGATE_TYPE);
  return expression_tree(*this, node(invalid_node()), negate, context_, INT_TYPE, shape());
}
/** @brief Read-only access to node @p idx of the flat node array
 *  (no bounds check is performed here). */
expression_tree::node const & expression_tree::operator[](size_t idx) const
{
  return tree_[idx];
}

/** @brief Mutable access to node @p idx of the flat node array
 *  (no bounds check is performed here). */
expression_tree::node & expression_tree::operator[](size_t idx)
{
  return tree_[idx];
}
// ---------------------------------------------------------------------------
// I/O: human-readable dumps of operators, nodes and whole trees.
// ---------------------------------------------------------------------------

// Expands to `case NAME: return "NAME"` for enum-to-string switches.
// NOTE(review): not referenced anywhere in this file as shown.
#define ISAAC_MAP_TO_STRING(NAME) \
  case NAME:                      \
    return #NAME
/** @brief Readable name of an operator.
 *
 *  The name comes from to_string(op.type). Reduction families are decorated:
 *  REDUCE -> "reduce<name>", REDUCE_ROWS -> "reduce<name, rows>",
 *  REDUCE_COLUMNS -> "reduce<name, cols>"; any other family returns the bare name.
 */
inline std::string to_string(const op_element& op)
{
  std::string const name = to_string(op.type);
  switch(op.type_family)
  {
    case REDUCE:         return "reduce<" + name + ">";
    case REDUCE_ROWS:    return "reduce<" + name + ", rows>";
    case REDUCE_COLUMNS: return "reduce<" + name + ", cols>";
    default:             return name;
  }
}
/** @brief One-line description of a single node.
 *
 *  Operator nodes render as "node (<lhs index>, <operator>, <rhs index>)".
 *  Leaves render as "empty", "scalar" or "array"; any other type as "unknown".
 */
inline std::string to_string(const expression_tree::node &node)
{
  switch(node.type)
  {
    case COMPOSITE_OPERATOR_TYPE:
    {
      std::string const left = tools::to_string(node.binary_operator.lhs);
      std::string const oper = to_string(node.binary_operator.op);
      std::string const right = tools::to_string(node.binary_operator.rhs);
      std::string text = "node (";
      text += left;
      text += ", ";
      text += oper;
      text += ", ";
      text += right;
      text += ")";
      return text;
    }
    case INVALID_SUBTYPE:   return "empty";
    case VALUE_SCALAR_TYPE: return "scalar";
    case DENSE_ARRAY_TYPE:  return "array";
    default:                return "unknown";
  }
}
namespace detail
{
/** @brief Recursive worker routine for printing a whole expression_tree */
inline void print_node(std::ostream & os, isaac::expression_tree const & s, size_t index, size_t indent = 0)
{
expression_tree::data_type const & data = s.data();
expression_tree::node const & node = data[index];
for (size_t i=0; i<indent; ++i)
os << " ";
os << "Node " << index << ": " << to_string(node) << std::endl;
if (node.type == COMPOSITE_OPERATOR_TYPE)
{
print_node(os, s, node.binary_operator.lhs, indent+1);
print_node(os, s, node.binary_operator.rhs, indent+1);
}
}
}
/** @brief Multi-line dump of a whole expression tree: one indented line per
 *  node, starting from the root. */
std::string to_string(isaac::expression_tree const & s)
{
  std::ostringstream buffer;
  detail::print_node(buffer, s, s.root());
  return buffer.str();
}
}