Files
triton/lib/jit/syntax/engine/object.cpp
2017-01-17 20:33:46 -05:00

357 lines
12 KiB
C++

/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <string>
#include "isaac/array.h"
#include "isaac/exception/api.h"
#include "isaac/jit/syntax/engine/object.h"
#include "isaac/jit/syntax/expression/expression.h"
#include "isaac/tools/cpp/string.hpp"
namespace isaac
{
namespace symbolic
{
//
// Registers `name` as a type tag of this object. Tags are pushed at the front
// of hierarchy_, so the tag added last is the first one evaluate() tries.
void object::add_base(const std::string &name)
{
  hierarchy_.push_front(name);
}
// Convenience constructor: builds the symbol name "obj<id>" from the numeric
// id and delegates to the name-based constructor.
object::object(driver::Context const & context, std::string const & scalartype, unsigned int id)
  : object(context, scalartype, "obj" + tools::to_string(id))
{
}
void object::add_load(bool contiguous)
{
macros_.insert("loadv(i): at(i)");
macros_.insert("loadv(i,j): at(i,j)");
driver::backend_type backend = context_.backend();
if(contiguous && backend==driver::OPENCL)
{
macros_.insert("loadv2(i): vload2(0, &at(i))");
macros_.insert("loadv2(i,j): vload2(0, &at(i,j))");
macros_.insert("loadv4(i): vload4(0, &at(i))");
macros_.insert("loadv4(i,j): vload4(0, &at(i,j))");
}
else
{
auto prefix = [&](std::string const & w){ return (backend==driver::OPENCL)?"(#scalartype"+w+")":"make_#scalartype"+w; };
std::string prefix2 = prefix("2"), prefix4 = prefix("4");
macros_.insert("loadv2(i): " + prefix2 + "(at(i), at(i+1))");
macros_.insert("loadv2(i,j): " + prefix2 + "(at(i,j), at(i+1,j))");
macros_.insert("loadv4(i): " + prefix4 + "(at(i), at(i+1), at(i+2), at(i+3))");
macros_.insert("loadv4(i,j): " + prefix4 + "(at(i,j), at(i+1,j), at(i+2,j), at(i+3,j))");
}
}
// Primary constructor: stores the driver context, records the "scalartype"
// and "name" attributes (substituted for #scalartype / #name by process()),
// and tags the instance with the root type "object".
object::object(driver::Context const & context, std::string const & scalartype, std::string const & name)
  : context_(context)
{
  attributes_["scalartype"] = scalartype;
  attributes_["name"] = name;
  add_base("object");
}
// Nothing to release explicitly; members clean up after themselves.
object::~object() = default;
// Expands a template string against this object.
//
// 1. Macros are applied until none of them changes the string any more. Each
//    pass stops at the first macro that expands and then restarts from the
//    beginning of macros_.
// 2. Every "#<attribute>" occurrence is then replaced by the attribute value.
//
// Returns the expanded copy; `in` itself is left untouched.
std::string object::process(std::string const & in) const
{
  std::string text = in;

  // Macro expansion, run to a fixed point.
  bool expanded;
  do
  {
    expanded = false;
    for(auto const & m : macros_)
    {
      if(m.expand(text))
      {
        expanded = true;
        break;
      }
    }
  } while(expanded);

  // Attribute substitution.
  for(auto const & attr : attributes_)
    tools::find_and_replace(text, "#" + attr.first, attr.second);

  return text;
}
// Returns true when an attribute called `name` has been set on this object.
bool object::hasattr(std::string const & name) const
{
  return attributes_.count(name) > 0;
}
// Picks the template registered for this object's type and expands it.
//
// table: maps a type tag (as registered through add_base) to a template
//        string.
// The tags in hierarchy_ are tried in order, and the first one present in
// `table` wins; its template is returned after process().
//
// Throws the string literal "NOT FOUND" (a `const char*`) when no tag of this
// object is present in `table`.
// NOTE(review): the thrown type is kept as-is because callers may catch
// `const char*`; an exception derived from std::exception would be cleaner.
std::string object::evaluate(std::map<std::string, std::string> const & table) const
{
  for(std::string const & type : hierarchy_)
  {
    // Keyed lookup instead of scanning every table entry for each tag.
    std::map<std::string, std::string>::const_iterator hit = table.find(type);
    if(hit != table.end())
      return process(hit->second);
  }
  throw "NOT FOUND";
}
//
// Id-based leaf: forwards to the id-based object constructor and adds the
// "leaf" type tag.
leaf::leaf(driver::Context const & context, std::string const & scalartype, unsigned int id)
  : object(context, scalartype, id)
{
  add_base("leaf");
}
// Name-based leaf: forwards to the name-based object constructor and adds the
// "leaf" type tag.
leaf::leaf(driver::Context const & context, std::string const & scalartype, std::string const & name)
  : object(context, scalartype, name)
{
  add_base("leaf");
}
//
// Binds an expression-tree node to the symbolic objects of its operands.
//
// root:  index of the node inside `tree`.
// op:    operator carried by the node.
// table: symbols already created, keyed by tree index.
// lhs_ / rhs_ stay null when the corresponding operand index has no entry in
// `table`.
node::node(size_t root, op_element op, expression_tree const & tree, symbols_table const & table)
  : op_(op), lhs_(nullptr), rhs_(nullptr), root_(root)
{
  expression_tree::node const & current = tree[root];

  symbols_table::const_iterator found = table.find(current.binary_operator.lhs);
  if(found != table.end())
    lhs_ = found->second.get();

  found = table.find(current.binary_operator.rhs);
  if(found != table.end())
    rhs_ = found->second.get();
}
// Operator stored for this node at construction.
op_element node::op() const
{
  return op_;
}
// Symbolic object of the left operand; null if none was found in the symbols
// table at construction.
object const * node::lhs() const
{
  return lhs_;
}
// Symbolic object of the right operand; null if none was found in the symbols
// table at construction.
object const * node::rhs() const
{
  return rhs_;
}
// Index of this node inside the expression tree it was built from.
size_t node::root() const
{
  return root_;
}
//
// Symbolic for-loop node. The scalar type is taken from the dtype of the tree
// node at `root`; the instance is tagged "sfor".
sfor::sfor(unsigned int id, size_t root, op_element op, expression_tree const & tree, symbols_table const & table)
  : object(tree.context(), to_string(tree[root].dtype), id),
    node(root, op, tree, table)
{
  add_base("sfor");
}
//
// Common part of unary/binary arithmetic nodes: besides the object and node
// sub-objects, it caches the textual form of the operator in op_str_.
arithmetic_node::arithmetic_node(unsigned int id, size_t root, op_element op, expression_tree const & tree, symbols_table const & table)
  : object(tree.context(), to_string(tree[root].dtype), id),
    node(root, op, tree, table),
    op_str_(to_string(op.type))
{
}
//
// Arithmetic node with two operands; tagged "binary_arithmetic_node".
binary_arithmetic_node::binary_arithmetic_node(unsigned int id, size_t root, op_element op, expression_tree const & expression, symbols_table const & table)
  : arithmetic_node(id, root, op, expression, table)
{
  add_base("binary_arithmetic_node");
}
// Renders the node as source text. Both operands are evaluated against
// `table` (left first); the result is "op(lhs, rhs)" when the operator is a
// function and "(lhs op rhs)" without extra spaces otherwise.
// NOTE(review): lhs_ and rhs_ are dereferenced unchecked — both operands are
// assumed to have been resolved by the node constructor.
std::string binary_arithmetic_node::evaluate(std::map<std::string, std::string> const & table) const
{
  const std::string left = lhs_->evaluate(table);
  const std::string right = rhs_->evaluate(table);

  if(!is_function(op_.type))
    return "(" + left + op_str_ + right + ")";

  return op_str_ + "(" + left + ", " + right + ")";
}
//
// Arithmetic node with a single operand; tagged "unary_arithmetic_node".
unary_arithmetic_node::unary_arithmetic_node(unsigned int id, size_t root, op_element op, expression_tree const & tree, symbols_table const & table)
  : arithmetic_node(id, root, op, tree, table)
{
  add_base("unary_arithmetic_node");
}
// Renders the node as "op(operand)", the operand being the left child
// evaluated against `table`.
// NOTE(review): lhs_ is dereferenced unchecked.
std::string unary_arithmetic_node::evaluate(std::map<std::string, std::string> const & table) const
{
  const std::string operand = lhs_->evaluate(table);
  return op_str_ + "(" + operand + ")";
}
//
// Base of the reduction nodes. The scalar type comes from the dtype of the
// tree node at `root`; the instance is tagged "reduction".
reduction::reduction(unsigned int id, size_t root, op_element op, expression_tree const & tree, symbols_table const & table)
  : object(tree.context(), to_string(tree[root].dtype), id),
    node(root, op, tree, table)
{
  add_base("reduction");
}
//
// Reduction specialisation tagged "reduce_1d".
reduce_1d::reduce_1d(unsigned int id, size_t root, op_element op, expression_tree const & tree, symbols_table const & table)
  : reduction(id, root, op, tree, table)
{
  add_base("reduce_1d");
}
//
// Reduction specialisation tagged "reduce_2d".
reduce_2d::reduce_2d(unsigned int id, size_t root, op_element op, expression_tree const & tree, symbols_table const & table)
  : reduction(id, root, op, tree, table)
{
  add_base("reduce_2d");
}
//
// An "int" leaf named "sforidx<level>" (presumably the index variable of the
// sfor at nesting depth `level` — confirm against the callers). Every access
// form at(), at(i) and at(i,j) expands to the bare symbol name, and the load
// macros are registered in their element-wise (non-contiguous) flavour.
placeholder::placeholder(driver::Context const & context, unsigned int level)
  : leaf(context, "int", "sforidx" + tools::to_string(level))
{
  macros_.insert("at(): #name");
  macros_.insert("at(i): #name");
  macros_.insert("at(i,j): #name");
  // The tag used to be misspelled "placebolder", so evaluate() could never
  // match a table entry keyed on the class name and always fell back to the
  // "leaf"/"object" entries.
  add_base("placeholder");
  add_load(false);
}
//
// Leaf whose at(), at(i) and at(i,j) accesses all expand to "<name>_value".
// Tagged "host_scalar"; the load macros use the element-wise
// (non-contiguous) form.
host_scalar::host_scalar(driver::Context const & context, std::string const & scalartype, unsigned int id)
  : leaf(context, scalartype, id)
{
  macros_.insert("at(): #name_value");
  macros_.insert("at(i): #name_value");
  macros_.insert("at(i,j): #name_value");
  add_base("host_scalar");
  add_load(false);
}
//
// Leaf that exposes a "pointer" attribute, resolved once at construction to
// "<name>_pointer". Tagged "array".
array::array(driver::Context const & context, std::string const & scalartype, unsigned int id)
  : leaf(context, scalartype, id)
{
  attributes_["pointer"] = process("#name_pointer");
  add_base("array");
}
// Builds the macro text that maps a full-rank access onto the access that
// only keeps the dimensions of extent > 1.
//
// For a shape with N entries the result is
//   "at(arg0,...,argN-1) : at(<argK for every K with shape[K] > 1>)"
// e.g. a shape {4,1} gives "at(arg0,arg1) : at(arg0)".
std::string array::make_broadcast(const tuple &shape)
{
  std::string all_args;   // every dimension
  std::string kept_args;  // only dimensions with extent > 1

  for(size_t d = 0 ; d < shape.size() ; ++d)
  {
    const std::string arg = "arg" + tools::to_string(d);

    if(!all_args.empty())
      all_args += ",";
    all_args += arg;

    if(shape[d] > 1)
    {
      if(!kept_args.empty())
        kept_args += ",";
      kept_args += arg;
    }
  }

  return "at(" + all_args + ") : at(" + kept_args + ")";
}
//
// Strided buffer. dim_ is the number of entries of `shape` greater than 1.
//
// Attributes: "off" -> "<name>_off" and, for each d < dim_,
//             "inc<d>" -> "<name>_inc<d>".
// Access macro: at(x0,...,x<dim_-1>) expands to
//             #pointer[#off + (x0)*#inc0 + ...].
// Broadcast: when no entry is > 1, at(i) forwards to at(); when some entries
//             of `shape` are <= 1, the make_broadcast() macro is added too.
// Loads:     treated as contiguous when strides[0]==1 and shape[0]>1.
// NOTE(review): strides[0] and shape[0] are read unconditionally — both
// tuples are assumed to be non-empty.
buffer::buffer(driver::Context const & context, std::string const & scalartype, unsigned int id, const tuple &shape, tuple const & strides)
  : array(context, scalartype, id), dim_(numgt1(shape))
{
  attributes_["off"] = process("#name_off");

  // One pass per dimension: bind the increment attribute, name the index
  // argument and extend the linear offset expression.
  std::vector<std::string> indices;
  std::string offset = "#off";
  for(unsigned int d = 0 ; d < dim_ ; ++d)
  {
    const std::string suffix = tools::to_string(d);
    const std::string inc = "inc" + suffix;
    attributes_[inc] = process("#name_" + inc);
    indices.push_back("x" + suffix);
    offset += " + (" + indices.back() + ")*#" + inc;
  }
  macros_.insert("at(" + tools::join(indices, ",") + "): #pointer[" + offset + "]");

  // Broadcast
  if(dim_ == 0)
    macros_.insert("at(i): at()");
  if(dim_ != shape.size())
    macros_.insert(make_broadcast(shape));

  add_base("buffer");
  add_load(strides[0]==1 && shape[0]>1);
}
//
// Base of the nodes that re-index their operand (reshape, trans,
// diag_vector). It is both an array and a tree node, is tagged
// "index_modifier", and uses the element-wise (non-contiguous) load macros.
index_modifier::index_modifier(const std::string &scalartype, unsigned int id, size_t root, op_element op, expression_tree const & tree, symbols_table const & table)
  : array(tree.context(), scalartype, id),
    node(root, op, tree, table)
{
  add_base("index_modifier");
  add_load(false);
}
//Reshaping
// Reshape node: re-indexes accesses on the reshaped view into accesses on the
// operand (lhs_).
//
// Attributes: for every position d >= 1 with extent > 1, "new_inc<d>"
//             (shape of this node) and "old_inc<d>" (shape of the operand)
//             resolve to "<name>_new_inc<d>" / "<name>_old_inc<d>".
// Access macro: chosen from the number of extents > 1 in the new and old
//             shapes (0, 1 or 2 each); other combinations add no access
//             macro.
// NOTE(review): lhs_ is dereferenced unchecked.
reshape::reshape(std::string const & scalartype, unsigned int id, size_t root, op_element op, expression_tree const & tree, symbols_table const & table)
  : index_modifier(scalartype, id, root, op, tree, table)
{
  add_base("reshape");
  const tuple & new_shape = tree[root].shape;
  const tuple & old_shape = tree[tree[root].binary_operator.lhs].shape;

  // Attributes: one increment per non-trivial dimension past the first.
  auto bind_increments = [this](const tuple & shape, const std::string & prefix)
  {
    for(unsigned int d = 1 ; d < shape.size() ; ++d)
    {
      if(shape[d] > 1)
      {
        const std::string inc = prefix + tools::to_string(d);
        attributes_[inc] = process("#name_" + inc);
      }
    }
  };
  bind_increments(new_shape, "new_inc");
  bind_increments(old_shape, "old_inc");

  // Index modification
  const size_t new_gt1 = numgt1(new_shape);
  const size_t old_gt1 = numgt1(old_shape);

  // Macro signature, indexed by the rank of the reshaped view.
  static const char * const signatures[3] = {"at(): ", "at(i): ", "at(i,j): "};
  // Operand access, indexed as [rank of the view][rank of the operand].
  static const char * const accesses[3][3] = {
    {"at()", "at(0)", "at(0,0)"},
    {"at()", "at(i)", "at((i)%#old_inc1, (i)/#old_inc1)"},
    {"at()", "at((i) + (j)*#new_inc1)", "at(((i) + (j)*#new_inc1)%#old_inc1, ((i)+(j)*#new_inc1)/#old_inc1)"}
  };
  if(new_gt1 < 3 && old_gt1 < 3)
    macros_.insert(std::string(signatures[new_gt1]) + lhs_->evaluate({{"leaf", accesses[new_gt1][old_gt1]}}));

  // Broadcast
  if(new_gt1 == 0)
    macros_.insert("at(i): at()");
  if(new_gt1 != new_shape.size())
    macros_.insert(make_broadcast(new_shape));
}
//Transposition
// Transposition node. With r the number of extents > 1 in this node's shape,
// at(x0,...,x<r-1>) forwards to the operand with the index list rotated right
// by one (the last index comes first), which swaps the indices when r is 2.
// NOTE(review): lhs_ is dereferenced unchecked.
trans::trans(std::string const & scalartype, unsigned int id, size_t root, op_element op, expression_tree const & tree, symbols_table const & table)
  : index_modifier(scalartype, id, root, op, tree, table)
{
  add_base("trans");
  const tuple & shape = tree[root].shape;
  const auto rank = numgt1(shape);

  std::vector<std::string> indices;
  for(unsigned int d = 0 ; d < rank ; ++d)
    indices.push_back("x" + tools::to_string(d));

  std::vector<std::string> permuted(indices);
  if(permuted.size() > 1)
    std::rotate(permuted.begin(), permuted.end() - 1, permuted.end());

  const std::string operand_access = "at(" + tools::join(permuted, ",") + ")";
  macros_.insert("at(" + tools::join(indices, ",") + "): " + lhs_->evaluate({{"leaf", operand_access}}));

  // Broadcast
  if(rank == 0)
    macros_.insert("at(i): at()");
  if(rank != shape.size())
    macros_.insert(make_broadcast(shape));
}
//
// Diagonal-from-vector node: at(i,j) yields the operand's element i when
// i == j and 0 otherwise. A make_broadcast() macro is added when some extent
// of this node's shape is <= 1.
// NOTE(review): lhs_ is dereferenced unchecked.
diag_vector::diag_vector(const std::string &scalartype, unsigned int id, size_t root, op_element op, const expression_tree &tree, const symbols_table &table)
  : index_modifier(scalartype, id, root, op, tree, table)
{
  add_base("diag_vector");
  // The generated text ends up inside larger expressions, e.g. the
  // "(" + lhs + op + rhs + ")" built by binary_arithmetic_node::evaluate.
  // ?: binds more loosely than any arithmetic operator, so the conditional is
  // wrapped in parentheses, and so are i and j, as the other index macros in
  // this file already do for their arguments. This assumes macro expansion is
  // plain text substitution — confirm against macro::expand.
  macros_.insert("at(i,j): " + lhs_->evaluate({{"leaf","(((i)==(j))?at(i):0)"}}));
  tuple const & shape = tree[root].shape;
  if(numgt1(shape)!=shape.size())
    macros_.insert(make_broadcast(shape));
}
////
}
}