[ir][instruction] added identifier for each instruction

This commit is contained in:
Philippe Tillet
2019-09-19 16:25:36 -04:00
parent 1fd9be27ee
commit e35be1ddcf
26 changed files with 460 additions and 444 deletions

View File

@@ -55,6 +55,8 @@ private:
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v);
// populate all maps
void populate(ir::value *v);
public:
void run(ir::module &mod);

View File

@@ -23,15 +23,24 @@ class axes {
private:
void add_constraint(node_t x, node_t y);
void init_c_phi(ir::instruction *i);
void init_c_graph(ir::instruction *v);
// update graph
void update_graph_store(ir::instruction *i);
void update_graph_reduce(ir::instruction *i);
void update_graph_reshape(ir::instruction *i);
void update_graph_splat(ir::instruction *i);
void update_graph_trans(ir::instruction *i);
void update_graph_broadcast(ir::instruction *i);
void update_graph_dot(ir::instruction *i);
void update_graph_elementwise(ir::instruction *i);
void update_graph(ir::instruction *i);
// connected components
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
public:
axes();
void run(ir::module &mod);
unsigned get(ir::value *value, unsigned ax);
bool has(ir::value *value, unsigned ax);
unsigned get_id(ir::value *value, unsigned ax);
bool has_id(ir::value *value, unsigned ax);
private:
// constraints graph

View File

@@ -24,8 +24,9 @@ class layout {
typedef std::map <node_t, std::set<node_t>> graph_t;
private:
// create edge
// graph creation
void connect(ir::value *x, ir::value *y);
void make_graph(ir::instruction *i);
// connected components
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned id);
// list the axes of the given value

View File

@@ -1,31 +0,0 @@
#ifndef TDL_INCLUDE_CODEGEN_VECTORIZE_H
#define TDL_INCLUDE_CODEGEN_VECTORIZE_H
namespace triton {
namespace ir {
class module;
}
namespace codegen{
namespace analysis{
class tiles;
}
namespace transform{
// Pass object exposing a single run(ir::module&) entry point.
// It keeps a pointer to a tile analysis; no destructor is declared here,
// so the class never deletes that pointer itself.
class vectorize {
public:
// Stores `params` as-is in params_ (no copy of the pointee is made).
vectorize(analysis::tiles *params): params_(params){}
// Applies the pass to `mod`; the definition lives out of line.
void run(ir::module &mod);
private:
// Tile analysis handed to the constructor.
analysis::tiles *params_;
};
}
}
}
#endif

View File

@@ -140,7 +140,6 @@ public:
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
// Intrinsics
value *create_copy_to_shared(value *arg, const std::string &name = "");
value *create_vectorize(value *arg, const std::string &name = "");
value *create_barrier(const std::string &name = "");
private:

View File

@@ -77,6 +77,67 @@ enum cmp_pred_t {
LAST_ICMP_PREDICATE
};
// Identifier attached to every instruction kind (underlying type: unsigned).
// instruction::get_id() returns one of these values, which lets passes
// dispatch with a `switch` instead of a chain of dynamic_casts.
// Enumerators take consecutive values starting at INST_BEGIN == 0.
// NOTE(review): INST_BEGIN looks like a range marker rather than a real
// instruction kind -- confirm before relying on it.
enum value_id_t: unsigned {
/* ------------ *
INSTRUCTIONS
* ------------ */
INST_BEGIN,
// phi
INST_PHI,
// arithmetic
INST_BINOP,
INST_GETELEMENTPTR,
INST_SELECT,
INST_SQRT,
// cmp
INST_ICMP,
INST_FCMP,
// cast
INST_CAST_TRUNC,
INST_CAST_ZEXT,
INST_CAST_SEXT,
INST_CAST_FP_TRUNC,
INST_CAST_FP_EXT,
INST_CAST_UI_TO_FP,
INST_CAST_SI_TO_FP,
INST_CAST_FP_TO_UI,
INST_CAST_FP_TO_SI,
INST_CAST_PTR_TO_INT,
INST_CAST_INT_TO_PTR,
INST_CAST_BIT_CAST,
INST_CAST_ADDR_SPACE_CAST,
// terminators
INST_RETURN,
INST_COND_BRANCH,
INST_UNCOND_BRANCH,
// io
INST_UNMASKED_LOAD,
INST_MASKED_LOAD,
INST_UNMASKED_STORE,
INST_MASKED_STORE,
// retile
INST_RESHAPE,
INST_SPLAT,
INST_BROADCAST,
INST_DOWNCAST,
// builtin
INST_GET_PROGRAM_ID,
INST_GET_NUM_PROGRAMS,
// atomics
INST_ATOMIC_CAS,
INST_ATOMIC_EXCH,
INST_ATOMIC_ADD,
// array arithmetic
INST_TRANS,
INST_REDUCE,
INST_DOT,
// intrinsics
INST_COPY_TO_SHARED,
INST_BARRIER,
INST_MAKE_RANGE_DYN,
INST_MAKE_RANGE_STA,
INST_MAKE_RANGE
};

View File

@@ -40,7 +40,8 @@ private:
protected:
// constructors
instruction(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
instruction(type *ty, value_id_t ity, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
// parent
@@ -59,32 +60,21 @@ public:
// cloning
// Returns a copy of this instruction produced by clone_impl().
// The copy is registered as a user of every operand and is detached from
// any basic block (parent_ is reset to nullptr).
ir::instruction* clone() {
ir::instruction* res = clone_impl();
// NOTE(review): the commented-out loop below duplicates the live loop
// further down -- looks like a leftover, confirm and remove.
// for(auto it = op_begin(); it != op_end(); it++){
// (*it)->add_use(res);
// }
// NOTE(review): every clone gets the fixed name "testcloned"; this looks
// like a debugging aid -- confirm it is intended.
res->set_name("testcloned");
// register the clone as a user of each operand of this instruction
for(auto it = op_begin(); it != op_end(); it++)
(*it)->add_use(res);
res->parent_ = nullptr;
return res;
}
// instruction id
value_id_t get_id() const { return id_; }
private:
basic_block *parent_;
std::map<ir::metadata::kind_t, unsigned> metadatas_;
value_id_t id_;
};
// result reference
class result_reference: public value {
public:
result_reference(instruction *ref, unsigned arg_id, const std::string &name = "");
instruction *get_ref();
unsigned get_arg_id();
private:
instruction *ref_;
unsigned arg_id_;
};
//===----------------------------------------------------------------------===//
// phi_node classes
//===----------------------------------------------------------------------===//
@@ -173,11 +163,13 @@ public:
class cmp_inst: public instruction{
public:
typedef cmp_pred_t pred_t;
private:
std::string repr_impl() const;
protected:
cmp_inst(type *ty, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next);
cmp_inst(type *ty, value_id_t id, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
static bool is_fp_predicate(cmp_pred_t pred);
static bool is_int_predicate(cmp_pred_t pred);
static type* make_cmp_result_type(type *ty);
@@ -190,7 +182,8 @@ private:
};
class icmp_inst: public cmp_inst {
using cmp_inst::cmp_inst;
icmp_inst(type *ty, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
public:
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
@@ -199,7 +192,8 @@ public:
};
class fcmp_inst: public cmp_inst {
using cmp_inst::cmp_inst;
fcmp_inst(type *ty, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next);
public:
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
@@ -213,7 +207,7 @@ public:
class unary_inst: public instruction {
protected:
unary_inst(type *Ty, value *v, const std::string &name, instruction *next);
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
};
@@ -226,8 +220,8 @@ private:
std::string repr_impl() const;
protected:
cast_inst(type *ty, value *v, const std::string &name, instruction *next, cast_op_t op)
: unary_inst(ty, v, name, next), op_(op) { }
cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op)
: unary_inst(ty, id, v, name, next), op_(op) { }
private:
static bool is_valid(cast_op_t op, value *arg, type *ty);
@@ -246,27 +240,27 @@ private:
cast_op_t op_;
};
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \
class name : public cast_inst { \
_TRITON_DEFINE_CLONE(name); \
friend class cast_inst; \
name(type *ty, value *v, const std::string &name, instruction *next) \
: cast_inst(ty, v, name, next, op){ } \
: cast_inst(ty, id, v, name, next, op){ } \
};
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, cast_op_t::Trunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, cast_op_t::ZExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, cast_op_t::SExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, cast_op_t::FPTrunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, cast_op_t::FPExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, cast_op_t::UIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, cast_op_t::SIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, cast_op_t::FPToUI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, cast_op_t::FPToSI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, cast_op_t::PtrToInt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, cast_op_t::IntToPtr)
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, cast_op_t::BitCast)
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, cast_op_t::AddrSpaceCast)
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI)
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt)
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr)
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast)
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast)
//===----------------------------------------------------------------------===//
// terminator_inst classes
@@ -372,33 +366,38 @@ private:
class io_inst: public instruction {
protected:
io_inst(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
io_inst(type *ty, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
// value *get_mask() const;
// value *get_false_value() const;
};
// load
class load_inst: public io_inst {
protected:
load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next);
load_inst(value *ptr, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
private:
std::string repr_impl() const { return "load"; }
static type *get_pointee_type(type *ty);
public:
// factory method
static load_inst* create(value *ptr,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(load_inst)
};
// unmasked load
class unmasked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "unmasked_load"; }
unmasked_load_inst(value *ptr, const std::string &name, instruction *next);
public:
static unmasked_load_inst* create(value *ptr,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_load_inst)
};
// masked load
class masked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load"; }
@@ -416,22 +415,28 @@ public:
_TRITON_DEFINE_CLONE(masked_load_inst)
};
class store_inst: public io_inst{
// store
class store_inst: public io_inst {
protected:
store_inst(value *ptr, value *v, unsigned num_extra_ops,
const std::string &name, instruction *next);
private:
std::string repr_impl() const { return "store"; }
store_inst(value *ptr, value_id_t id, unsigned num_ops,
const std::string &name = "", instruction *next = nullptr);
public:
// accessors
value *get_value_operand() { return get_operand(1); }
};
// unmasked_store
class unmasked_store_inst: public store_inst{
private:
std::string repr_impl() const { return "unmasked_store"; }
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
public:
// factory method
static store_inst* create(value* ptr, value *v,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(store_inst)
static unmasked_store_inst* create(value* ptr, value *v,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(unmasked_store_inst)
};
class masked_store_inst: public store_inst{
@@ -458,7 +463,7 @@ public:
class retile_inst: public unary_inst {
protected:
retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
};
// reshape
@@ -690,16 +695,6 @@ public:
instruction *next = nullptr);
};
class vectorize_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "vectorize"; }
_TRITON_DEFINE_CLONE(vectorize_inst)
public:
static vectorize_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
};
// On NVIDIA, implementation is such that
// constant_range = nv_dynamic_program_idx + nv_static_program_idx
// so as to enable re-association on nv_static_program_idx which is constant

View File

@@ -4,18 +4,25 @@
#define _TRITON_IR_CFG_H_
#include <vector>
#include <functional>
namespace triton{
namespace ir{
class module;
class function;
class basic_block;
class instruction;
class value;
class cfg {
public:
static std::vector<basic_block *> reverse_post_order(function* fn);
};
// Calls `fn` with instructions of `mod`.
// NOTE(review): presumably visits every instruction of every function;
// the traversal order is not visible from this declaration -- confirm.
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
// Calls `fn` with values of `mod`.
// NOTE(review): which values are visited (instructions only, or also their
// operands/arguments) is not visible from this declaration -- confirm.
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
}
}

View File

@@ -20,7 +20,6 @@
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/membar.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/vectorize.h"
#include "triton/lang/parser.h"
#include "triton/runtime/arg.h"

View File

@@ -1,4 +1,5 @@
#include "triton/codegen/analysis/align.h"
#include "triton/ir/utils.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
@@ -14,15 +15,19 @@ namespace analysis{
// Greatest common divisor of a and b.
//
// Degenerate inputs keep their historical results:
//   gcd(0, b) == b, gcd(a, 0) == a and gcd(a, a) == a.
// Every other input is handled with the iterative Euclidean algorithm on the
// magnitudes of a and b. The previous implementation used one recursive call
// per subtraction (stack depth proportional to max(a, b)) and never
// terminated when the operands had different signs.
inline int gcd(int a, int b) {
  if (a == 0)
    return b;
  if (b == 0)
    return a;
  if (a == b)
    return a;
  // Work on magnitudes in a wider type so that negating INT_MIN is safe.
  long long x = a < 0 ? -static_cast<long long>(a) : a;
  long long y = b < 0 ? -static_cast<long long>(b) : b;
  while (y != 0) {
    const long long r = x % y;
    x = y;
    y = r;
  }
  // The result divides both magnitudes, and at least one of them fits in int
  // here (a != b), so the narrowing conversion cannot overflow.
  return static_cast<int>(x);
}
// Least common multiple of a and b, computed from gcd(a, b).
//
// The division happens before the multiplication so the intermediate value
// stays as small as possible (a * b could overflow even when the result
// fits). When gcd(a, b) is 0, i.e. a == b == 0, the function returns 0
// instead of dividing by zero.
inline int lcm(int a, int b) {
  const int g = gcd(a, b);
  if (g == 0)
    return 0;
  return (a / g) * b;
}
template<class T>
@@ -64,8 +69,8 @@ std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
auto shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value* op = x->get_operand(0);
std::vector<cst_info> result;
auto op_cst = populate_is_constant(op);
for(auto d: shapes)
result.push_back(cst_info{d, op_cst[0].value});
@@ -478,28 +483,15 @@ std::vector<unsigned> align::contiguous(ir::value* v) const {
return max_contiguous_.at(v);
}
// Runs the three per-value analyses of this pass on `v`, in this order:
// constant-ness, starting multiple, maximum contiguity.
// NOTE(review): later analyses presumably read results of earlier ones, so
// the call order should be kept -- confirm before reordering.
void align::populate(ir::value *v) {
populate_is_constant(v);
populate_starting_multiple(v);
populate_max_contiguous(v);
}
// Entry point of the alignment analysis for `mod`.
// Walks every instruction of every basic block of every function three
// times (one walk per analysis), then calls populate() through
// ir::for_each_value.
// NOTE(review): populate() calls the same three analyses as the explicit
// loops, so one of the two mechanisms looks redundant -- confirm which one
// is meant to stay.
void align::run(ir::module &mod) {
// populate constant
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
populate_is_constant(i);
}
// populate starting multiple
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
populate_starting_multiple(i);
}
// populate maximum contiguous
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
populate_max_contiguous(i);
}
// same three analyses, applied value by value
ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
}

View File

@@ -1,5 +1,6 @@
#include "triton/codegen/analysis/axes.h"
#include "triton/ir/instructions.h"
#include "triton/ir/utils.h"
#include "triton/ir/type.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
@@ -30,91 +31,113 @@ void axes::add_constraint(node_t x, node_t y) {
nodes_.insert(y);
}
void axes::init_c_graph(ir::instruction *v) {
// Reference shape
ir::type::tile_shapes_t shapes;
if(auto *store = dynamic_cast<ir::store_inst*>(v))
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
else if(auto *atom = dynamic_cast<ir::atomic_add_inst*>(v))
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
else if(dynamic_cast<ir::downcast_inst*>(v))
// Reduction: every input axis except the reduced one is tied to the next
// free axis of the result, in increasing order.
// `i` must be an ir::reduce_inst (it is static_cast without a check).
void axes::update_graph_reduce(ir::instruction *i) {
  auto *reduce = static_cast<ir::reduce_inst*>(i);
  ir::value *input = reduce->get_operand(0);
  unsigned reduced_axis = reduce->get_axis();
  auto input_shapes = input->get_type()->get_tile_shapes();
  // index of the next result axis to be matched
  unsigned out_axis = 0;
  for(unsigned in_axis = 0; in_axis < input_shapes.size(); in_axis++){
    if(in_axis != reduced_axis){
      add_constraint({i, out_axis}, {input, in_axis});
      out_axis++;
    }
  }
}
// Reshape: result axes are matched left to right against operand axes.
// While the reshape is not "skewed", a result axis whose extent equals the
// extent of the current operand axis is tied to it; any other result axis
// only gets its own node (self-edge). The first mismatching result axis with
// an extent greater than 1 marks the reshape as skewed, after which no more
// operand axes are tied.
// `i` must be an ir::reshape_inst (it is static_cast without a check).
void axes::update_graph_reshape(ir::instruction *i) {
  auto* reshape = static_cast<ir::reshape_inst*>(i);
  // operands
  ir::value *op = reshape->get_operand(0);
  // shapes
  auto op_shapes = op->get_type()->get_tile_shapes();
  auto res_shapes = reshape->get_type()->get_tile_shapes();
  // construct edges
  unsigned current = 0;
  bool is_skewed = false;
  for(unsigned d = 0; d < res_shapes.size(); d ++){
    // Once every operand axis has been consumed there is nothing left to
    // compare against: treat it as a mismatch instead of reading past the
    // end of op_shapes (e.g. when the reshape appends axes).
    bool same_shape = current < op_shapes.size() &&
                      res_shapes[d] == op_shapes[current];
    // either add edge between axis or just add a node in the graph
    if(!is_skewed && same_shape)
      add_constraint({i, d}, {op, current++});
    else
      add_constraint({i, d}, {i, d});
    // reshaping is skewed
    if(res_shapes[d] > 1 && !same_shape)
      is_skewed = true;
  }
}
// Splat: the argument is a scalar, so there is no operand axis to connect
// and the graph is left untouched.
void axes::update_graph_splat(ir::instruction *) {
}
// Transposition: operand axis d is tied to result axis perm[d].
// `i` must be an ir::trans_inst (it is static_cast without a check).
void axes::update_graph_trans(ir::instruction *i) {
  auto *transpose = static_cast<ir::trans_inst*>(i);
  auto permutation = transpose->get_perm();
  ir::value *source = transpose->get_operand(0);
  for(unsigned src_axis = 0; src_axis < permutation.size(); src_axis++){
    add_constraint({i, permutation[src_axis]->get_value()}, {source, src_axis});
  }
}
// Broadcast: an axis is tied to the matching operand axis only when both
// have the same extent, i.e. when that axis is not being broadcast.
// `i` must be an ir::broadcast_inst (it is static_cast without a check).
void axes::update_graph_broadcast(ir::instruction *i) {
  auto *bcast = static_cast<ir::broadcast_inst*>(i);
  auto result_shapes = bcast->get_type()->get_tile_shapes();
  ir::value *source = bcast->get_operand(0);
  ir::type *source_ty = source->get_type();
  const auto& source_shapes = source_ty->get_tile_shapes();
  for(unsigned d = 0; d < result_shapes.size(); d++){
    // broadcast axis: extents differ, leave it unconnected
    if(!(source_shapes[d] == result_shapes[d]))
      continue;
    add_constraint({i, d}, {source, d});
  }
}
// Dot product: every result axis is tied to the same axis of the third
// operand (the accumulator); axes from index 2 onwards (batch dimensions)
// are additionally tied to the same axis of the first two operands.
// `i` must be an ir::dot_inst (it is static_cast without a check).
void axes::update_graph_dot(ir::instruction *i) {
  auto *dot = static_cast<ir::dot_inst*>(i);
  ir::value *lhs = dot->get_operand(0);
  ir::value *rhs = dot->get_operand(1);
  ir::value *acc = dot->get_operand(2);
  auto result_shapes = dot->get_type()->get_tile_shapes();
  // result <-> accumulator
  for(unsigned axis = 0; axis < result_shapes.size(); axis++){
    add_constraint({dot, axis}, {acc, axis});
  }
  // batch dimensions <-> both multiplicands
  for(unsigned axis = 2; axis < result_shapes.size(); axis++){
    add_constraint({dot, axis}, {lhs, axis});
    add_constraint({dot, axis}, {rhs, axis});
  }
}
void axes::update_graph_elementwise(ir::instruction *i) {
if(i->get_num_operands() == 0)
return;
else if(dynamic_cast<ir::copy_to_shared_inst*>(v))
return;
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
unsigned axis = reduce->get_axis();
ir::value *arg = reduce->get_operand(0);
auto in_shapes = arg->get_type()->get_tile_shapes();
unsigned current = 0;
for(unsigned i = 0; i < in_shapes.size(); i++){
if(i == axis)
continue;
add_constraint({reduce, current++}, {arg, i});
}
ir::value *op = i->get_operand(0);
if(!op->get_type()->is_tile_ty())
return;
auto rank = op->get_type()->get_tile_rank();
for(unsigned d = 0; d < rank; d++)
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
if(!i->get_type()->is_void_ty())
add_constraint({i, d}, {opx, d});
add_constraint({opx, d}, {opy, d});
}
else
shapes = v->get_type()->get_tile_shapes();
// Reshape
if(dynamic_cast<ir::reshape_inst*>(v)) {
ir::value *op = v->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
unsigned current = 0;
bool is_skewed = false;
for(unsigned i = 0; i < shapes.size(); i ++){
if(shapes[i] == 1){
add_constraint({v, i}, {v, i});
}
else if(!is_skewed &&
shapes[i] == op_shapes[current])
add_constraint({v, i}, {op, current++});
else{
is_skewed = true;
add_constraint({v, i}, {v, i});
}
}
}
// Splat
else if(dynamic_cast<ir::splat_inst*>(v)){
return;
}
// Trans
else if(auto *x = dynamic_cast<ir::trans_inst*>(v)){
ir::value *op = v->get_operand(0);
auto perm = x->get_perm();
for(unsigned i = 0; i < perm.size(); i++)
add_constraint({v, perm[i]->get_value()}, {op, i});
}
// Broadcast
else if(dynamic_cast<ir::broadcast_inst*>(v)){
ir::value *op = v->get_operand(0);
ir::type *op_ty = op->get_type();
const auto& op_shapes = op_ty->get_tile_shapes();
for(unsigned i = 0; i < shapes.size(); i ++){
if(op_shapes[i] == shapes[i] && v != op)
add_constraint({v, i}, {op, i});
}
}
// Matrix multiplication
else if(dynamic_cast<ir::dot_inst*>(v)){
ir::value *A = v->get_operand(0);
ir::value *B = v->get_operand(1);
ir::value *D = v->get_operand(2);
for(unsigned i = 0; i < shapes.size(); i++)
add_constraint({v, i}, {D, i});
for(unsigned i = 2; i < shapes.size(); i++){
add_constraint({v, i}, {A, i});
add_constraint({v, i}, {B, i});
}
}
// Element-wise
else if(dynamic_cast<ir::user*>(v)) {
for(unsigned i = 0; i < shapes.size(); i ++){
std::vector<ir::value*> ops = v->ops();
for(ir::value* op: ops)
add_constraint({v, i}, {op, i});
}
}
// Dispatches `i` to the graph-update routine matching its instruction id.
// Any id without a dedicated routine is handled as an element-wise
// instruction.
void axes::update_graph(ir::instruction *i) {
  switch (i->get_id()) {
    case ir::INST_REDUCE:
      update_graph_reduce(i);
      break;
    case ir::INST_RESHAPE:
      update_graph_reshape(i);
      break;
    case ir::INST_SPLAT:
      update_graph_splat(i);
      break;
    case ir::INST_TRANS:
      update_graph_trans(i);
      break;
    case ir::INST_BROADCAST:
      update_graph_broadcast(i);
      break;
    case ir::INST_DOT:
      update_graph_dot(i);
      break;
    default:
      update_graph_elementwise(i);
      break;
  }
}
void axes::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
@@ -126,12 +149,12 @@ void axes::connected_components(node_t x, std::set<node_t> &nodes, graph_t &grap
}
}
unsigned axes::get(ir::value *value, unsigned ax) {
unsigned axes::get_id(ir::value *value, unsigned ax) {
unsigned result = groups_.at(value).at(ax);
return result;
}
bool axes::has(ir::value *value, unsigned ax) {
bool axes::has_id(ir::value *value, unsigned ax) {
auto it = groups_.find(value);
if(it == groups_.end())
return false;
@@ -146,15 +169,9 @@ void axes::run(ir::module &mod) {
nodes_.clear();
dependencies_.clear();
groups_.clear();
// Create graph
for(ir::function *fn: mod.get_function_list()){
// Build constraints graph
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list())
if(i->has_tile_result_or_op())
init_c_graph(i);
}
// Axes
// make graph
ir::for_each_instruction(mod, [this](ir::instruction *x) { update_graph(x); });
// connected components
unsigned group_id = 0;
while(!nodes_.empty())
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);

View File

@@ -4,6 +4,7 @@
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/utils.h"
namespace triton{
namespace codegen{
@@ -20,8 +21,8 @@ std::set<int> layout::axes_of(ir::value *value) {
// create result
std::set<int> result;
for(size_t d = 0; d < rank; d++){
if(axes_->has(value, d))
result.insert(axes_->get(value, d));
if(axes_->has_id(value, d))
result.insert(axes_->get_id(value, d));
}
return result;
}
@@ -74,24 +75,23 @@ void layout::connect(ir::value *x, ir::value *y) {
}
}
// Connects instruction `i` with each of its operands, and each operand with
// every operand of `i` (itself included), through connect().
// An instruction without operands adds nothing.
void layout::make_graph(ir::instruction *i) {
  auto &&operands = i->ops();
  for(ir::value *lhs: operands){
    for(ir::value *rhs: operands){
      connect(i, lhs);
      connect(lhs, rhs);
    }
  }
}
// run
void layout::run(ir::module &mod) {
nodes_.clear();
dependencies_.clear();
groups_.clear();
values_.clear();
// Create graph
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()) {
for(ir::value* opx: i->ops())
for(ir::value* opy: i->ops()){
connect(i, opx);
connect(opx, opy);
}
}
// Grids
// make graph
ir::for_each_instruction(mod, [this](ir::instruction* i) { make_graph(i); });
// connected components
unsigned group_id = 0;
while(!nodes_.empty()){
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);

View File

@@ -82,9 +82,9 @@ void add_copy(ir::value *x, ir::builder &builder) {
}
void meminfo::run(ir::module &mod) {
// shared_.clear();
// refs_.clear();
// double_.clear();
shared_.clear();
refs_.clear();
double_.clear();
// Add shared copies
for(ir::function *fn: mod.get_function_list()){

View File

@@ -43,19 +43,19 @@ bool tiles::hmma(ir::value *value) {
}
// Looks up mts_ with the id that the axes analysis reports for axis `ax`
// of `value`.
// NOTE(review): the second return is unreachable; it looks like the
// intended replacement (get -> get_id) -- confirm and drop one of the two.
int tiles::mts(ir::value *value, unsigned ax) {
return mts_.at(axes_->get(value, ax));
return mts_.at(axes_->get_id(value, ax));
}
// Looks up nts_ with the id that the axes analysis reports for axis `ax`
// of `value`.
// NOTE(review): the second return is unreachable; it looks like the
// intended replacement (get -> get_id) -- confirm and drop one of the two.
int tiles::nts(ir::value *value, unsigned ax) {
return nts_.at(axes_->get(value, ax));
return nts_.at(axes_->get_id(value, ax));
}
// Looks up fpw_ with the id that the axes analysis reports for axis `ax`
// of `value`.
// NOTE(review): the second return is unreachable; it looks like the
// intended replacement (get -> get_id) -- confirm and drop one of the two.
int tiles::fpw(ir::value *value, unsigned ax) {
return fpw_.at(axes_->get(value, ax));
return fpw_.at(axes_->get_id(value, ax));
}
// Looks up wpt_ with the id that the axes analysis reports for axis `ax`
// of `value`.
// NOTE(review): the second return is unreachable; it looks like the
// intended replacement (get -> get_id) -- confirm and drop one of the two.
int tiles::wpt(ir::value *value, unsigned ax) {
return wpt_.at(axes_->get(value, ax));
return wpt_.at(axes_->get_id(value, ax));
}
std::vector<int> tiles::order(ir::value *v) {
@@ -92,7 +92,7 @@ void tiles::init_hmma_tile(ir::value *i) {
}while(fpw_nm1 != fpw);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
fpw_[axes_->get(i, d)] = fpw[d];
fpw_[axes_->get_id(i, d)] = fpw[d];
/* warps per tile */
// try to make things as square as possible to maximize data re-use
std::vector<unsigned> wpt = {1, 1, 1};
@@ -106,11 +106,11 @@ void tiles::init_hmma_tile(ir::value *i) {
}while(wpt_nm1 != wpt);
// store parameters
for(unsigned d = 0; d < shapes.size(); d++)
wpt_[axes_->get(i, d)] = wpt[d];
wpt_[axes_->get_id(i, d)] = wpt[d];
/* sanity check */
unsigned effective_num_warps = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_warps *= wpt_[axes_->get(i, d)];
effective_num_warps *= wpt_[axes_->get_id(i, d)];
if(num_warps_ != effective_num_warps)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}
@@ -122,19 +122,19 @@ void tiles::init_scanline_tile(ir::value *i) {
unsigned ld = ord[0];
unsigned num_threads = num_warps_*32;
unsigned current = num_threads;
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4);
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]);
current = current / mts_[axes_->get(i, ld)];
nts_[axes_->get_id(i, ld)] = clamp(size / num_threads, 1, 4);
mts_[axes_->get_id(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get_id(i, ld)]);
current = current / mts_[axes_->get_id(i, ld)];
for(size_t d = 1; d < shapes.size(); d++){
ld = ord[d];
nts_[axes_->get(i, ld)] = 1;
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]);
current = current / mts_[axes_->get(i, ld)];
nts_[axes_->get_id(i, ld)] = 1;
mts_[axes_->get_id(i, ld)] = clamp(current, 1, shapes[ld]);
current = current / mts_[axes_->get_id(i, ld)];
}
/* sanity check */
unsigned effective_num_threads = 1;
for(size_t d = 0; d < shapes.size(); d++)
effective_num_threads *= mts_[axes_->get(i, d)];
effective_num_threads *= mts_[axes_->get_id(i, d)];
if(num_threads != effective_num_threads)
throw std::runtime_error("cannot create a kernel with this amount of warps");
}

View File

@@ -615,7 +615,7 @@ void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
axes_[a_axes_->get_id(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
}
}
@@ -720,10 +720,10 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
/* axes */
axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
axes_[a_axes_->get_id(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[a_axes_->get_id(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
if(is_batched)
axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
axes_[a_axes_->get_id(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
}
@@ -791,7 +791,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
std::vector<distributed_axis> axes(shapes.size());
for(size_t d = 0; d < shapes.size(); d++){
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
unsigned x = a_axes_->get_id(v, d);
axes[d] = axes_.at(x);
}
else{
@@ -942,7 +942,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
for(auto& x: partial) {
// current element being computed
Value *lane = axes_.at(a_axes_->get(op, axis)).thread_id;
Value *lane = axes_.at(a_axes_->get_id(op, axis)).thread_id;
Value *&result = x.second;
indices_t write_idx = x.first;
write_idx.insert(write_idx.begin() + axis, lane);

View File

@@ -103,39 +103,9 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
// bool z_order = true;
// if(z_order && ax < 2){
// static std::array<Intrinsic::ID, 3> n_cta_ids = {
// Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
// Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
// Intrinsic::nvvm_read_ptx_sreg_nctaid_z
// };
// Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {});
// Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {});
// Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {});
// Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {});
// // global block ID
// Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0));
// // helper for minimum
// auto Min = [&](Value *x, Value *y){
// return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x);
// };
// // super-tile size
// Value* sts = Min(builder.getInt32(16), n_cta_id_1);
// // number of CTAs per super-block
// Value *nscta = builder.CreateMul(n_cta_id_0, sts);
// Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0);
// Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts));
// if(ax == 0)
// return bid0;
// else
// return bid1;
// }
// else{
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
Value* cta_id = builder.CreateCall(get_cta_id, {});
return cta_id;
// }
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
Value* cta_id = builder.CreateCall(get_cta_id, {});
return cta_id;
}
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {

View File

@@ -2,7 +2,7 @@
#include <algorithm>
#include <numeric>
#include "triton/ir/function.h"
#include "triton/ir/cfg.h"
#include "triton/ir/utils.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/module.h"

View File

@@ -1,7 +1,7 @@
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/cfg.h"
#include "triton/ir/utils.h"
#include "triton/codegen/transform/dce.h"
#include <iostream>

View File

@@ -9,7 +9,7 @@
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/cfg.h"
#include "triton/ir/utils.h"
namespace triton {

View File

@@ -6,7 +6,7 @@
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/cfg.h"
#include "triton/ir/utils.h"
namespace triton {
namespace codegen{

View File

@@ -1,41 +0,0 @@
#include "triton/codegen/transform/vectorize.h"
#include "triton/codegen/analysis/tiles.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
namespace triton {
namespace codegen{
namespace transform{
// Walks every instruction of every function in `mod` and wraps the first
// operand of selected instructions in a `vectorize` instruction:
//   - trans_inst whose first permutation entry is 0
//   - copy_to_shared_inst whose operand has params_->nts(x, 0) != 1
void vectorize::run(ir::module &mod) {
  ir::builder &builder = mod.get_builder();

  // Creates vectorize(operand) with the builder positioned at `user`, then
  // redirects every use of `operand` to the new instruction.
  auto wrap_operand = [&builder](ir::instruction *user, ir::value *operand) {
    builder.set_insert_point(user);
    ir::instruction *vec = (ir::instruction*)builder.create_vectorize(operand);
    operand->replace_all_uses_with(vec);
    // presumably the replacement above also rewired vec's own operand;
    // point it back at the original value
    vec->set_operand(0, operand);
  };

  for(ir::function *fn: mod.get_function_list()) {
    for(ir::basic_block *block: fn->blocks()) {
      for(ir::instruction *i: block->get_inst_list()) {
        // transposition
        if(auto *trans = dynamic_cast<ir::trans_inst*>(i)) {
          ir::value *src = i->get_operand(0);
          if(trans->get_perm()[0]->get_value() != 0)
            continue;
          wrap_operand(i, src);
        }
        // copy to shared memory
        if(dynamic_cast<ir::copy_to_shared_inst*>(i)) {
          ir::value *src = i->get_operand(0);
          if(params_->nts(src, 0) != 1)
            wrap_operand(i, src);
        }
      }
    }
  }
}
}
}
}

View File

@@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -252,11 +252,11 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
//===----------------------------------------------------------------------===//
value *builder::create_load(value *ptr, const std::string &name){
return insert(load_inst::create(ptr, name));
return insert(unmasked_load_inst::create(ptr, name));
}
value *builder::create_store(value *ptr, value *val, const std::string &name){
return insert(store_inst::create(ptr, val, name));
return insert(unmasked_store_inst::create(ptr, val, name));
}
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, const std::string &name){
@@ -340,10 +340,6 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) {
return insert(copy_to_shared_inst::create(arg, name));
}
value *builder::create_vectorize(value *arg, const std::string &name) {
return insert(vectorize_inst::create(arg, name));
}
value *builder::create_barrier(const std::string &name) {
return insert(barrier_inst::create(ctx_, name));
}

View File

@@ -1,31 +0,0 @@
#include <stack>
#include "triton/ir/cfg.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/function.h"
namespace triton{
namespace ir{
// Returns the basic blocks of `fn` that are reachable from its entry blocks
// (blocks without predecessors), each listed exactly once, in the order an
// iterative depth-first traversal first pops them.
// NOTE(review): a block is emitted when popped, i.e. before its successors,
// so this is a DFS pre-order rather than a strict reverse post-order --
// confirm that callers do not depend on the stricter ordering.
std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
  std::stack<basic_block*> stack;
  std::set<basic_block*> visited;
  std::vector<basic_block*> result;
  // initialize stack with every block that has no predecessor
  for(ir::basic_block* block: fn->blocks())
    if(block->get_predecessors().empty())
      stack.push(block);
  // DFS
  while(!stack.empty()) {
    basic_block* current = stack.top();
    stack.pop();
    // A block can be pushed by several predecessors before it is popped for
    // the first time; only the first pop may emit it, otherwise it would
    // appear more than once in the result.
    if(!visited.insert(current).second)
      continue;
    result.push_back(current);
    for(basic_block* succ: current->get_successors())
      if(visited.find(succ) == visited.end())
        stack.push(succ);
  }
  // return the local by name so that copy elision / implicit move applies
  return result;
}
}
}

View File

@@ -12,8 +12,9 @@ namespace ir{
// instruction classes
//===----------------------------------------------------------------------===//
instruction::instruction(type *ty, unsigned num_ops, unsigned num_results, const std::string &name, instruction *next)
: user(ty, num_ops, name) {
instruction::instruction(type *ty, value_id_t ity, unsigned num_ops,
const std::string &name, instruction *next)
: user(ty, num_ops, name), id_(ity) {
if(next){
basic_block *block = next->get_parent();
assert(block && "Next instruction is not in a basic block!");
@@ -35,17 +36,12 @@ bool instruction::has_tile_result_or_op() {
return result;
}
// result reference
result_reference::result_reference(instruction *ref, unsigned arg_id, const std::string &name)
: value(ref->get_type(), name), arg_id_(arg_id){ }
//===----------------------------------------------------------------------===//
// phi_node classes
//===----------------------------------------------------------------------===//
phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next)
: instruction(ty, 0, 1, name, next) {
: instruction(ty, INST_PHI, 0, name, next) {
blocks_.reserve(num_reserved);
}
@@ -131,7 +127,7 @@ bool binary_operator::is_int_add_sub() const {
binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
: instruction(ty, 2, 1, name, next), op_(op){
: instruction(ty, INST_BINOP, 2, name, next), op_(op){
set_operand(0, lhs);
set_operand(1, rhs);
}
@@ -164,6 +160,8 @@ binary_operator *binary_operator::create_not(value *arg, const std::string &name
// cmp_inst classes
//===----------------------------------------------------------------------===//
// cmp_inst
std::string cmp_inst::repr_impl() const {
switch (pred_) {
@@ -197,8 +195,8 @@ std::string cmp_inst::repr_impl() const {
}
}
cmp_inst::cmp_inst(type *ty, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
: instruction(ty, 2, 1, name, next), pred_(pred) {
cmp_inst::cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
: instruction(ty, id, 2, name, next), pred_(pred) {
set_operand(0, lhs);
set_operand(1, rhs);
}
@@ -219,7 +217,12 @@ bool cmp_inst::is_int_predicate(cmp_pred_t pred) {
return pred >= FIRST_ICMP_PREDICATE && pred <= LAST_ICMP_PREDICATE;
}
// icmp_inst
icmp_inst::icmp_inst(type *ty, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next)
: cmp_inst(ty, INST_ICMP, pred, lhs, rhs, name, next){ }
icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
assert(is_int_predicate(pred));
type *res_ty = make_cmp_result_type(lhs->get_type());
@@ -227,6 +230,10 @@ icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std:
}
// fcmp_inst
fcmp_inst::fcmp_inst(type *ty, cmp_pred_t pred,
value *lhs, value *rhs, const std::string &name, instruction *next)
: cmp_inst(ty, INST_FCMP, pred, lhs, rhs, name, next){ }
fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
assert(is_fp_predicate(pred));
type *res_ty = make_cmp_result_type(lhs->get_type());
@@ -237,8 +244,8 @@ fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std:
// unary_inst classes
//===----------------------------------------------------------------------===//
unary_inst::unary_inst(type *ty, value *v, const std::string &name, instruction *next)
: instruction(ty, 1, 1, name, next) {
unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next)
: instruction(ty, id, 1, name, next) {
set_operand(0, v);
}
@@ -309,7 +316,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
// return_inst
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
: terminator_inst(type::get_void_ty(ctx), ret_val!=nullptr, 0, "", next){
: terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
if(ret_val)
set_operand(0, ret_val);
}
@@ -332,13 +339,13 @@ branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block *
// uncond_branch_inst
uncond_branch_inst::uncond_branch_inst(basic_block *dst, instruction *next)
: branch_inst(type::get_void_ty(dst->get_context()), 1, 0, "", next){
: branch_inst(type::get_void_ty(dst->get_context()), INST_UNCOND_BRANCH, 1, "", next){
set_operand(0, dst);
}
// cond_branch_inst
cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next)
: branch_inst(type::get_void_ty(if_dst->get_context()), 3, 0, "", next){
: branch_inst(type::get_void_ty(if_dst->get_context()), INST_COND_BRANCH, 3, "", next){
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
set_operand(0, if_dst);
set_operand(1, else_dst);
@@ -351,7 +358,7 @@ cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, v
//===----------------------------------------------------------------------===//
getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next)
: instruction(get_return_type(pointee_ty, ptr, idx), 1 + idx.size(), 1, name, next),
: instruction(get_return_type(pointee_ty, ptr, idx), INST_GETELEMENTPTR, 1 + idx.size(), name, next),
source_elt_ty(pointee_ty),
res_elt_ty(get_indexed_type(pointee_ty, idx)){
// sanity check
@@ -414,8 +421,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<val
//===----------------------------------------------------------------------===//
// io_inst
io_inst::io_inst(type *ty, unsigned num_ops, unsigned num_results, const std::string &name, instruction *next)
: instruction(ty, num_ops, num_results, name, next)
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
: instruction(ty, id, num_ops, name, next)
{ }
// load_inst
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next)
{ }
// load
@@ -427,19 +439,21 @@ type *load_inst::get_pointee_type(type *ty) {
return pointee_ty;
}
load_inst::load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next)
: io_inst(get_pointee_type(ptr->get_type()), 1 + num_extra_ops, 1, name, next) {
// unmasked_load
unmasked_load_inst::unmasked_load_inst(value *ptr, const std::string &name, instruction *next)
: load_inst(ptr, INST_UNMASKED_LOAD, 1, name, next) {
set_operand(0, ptr);
}
load_inst* load_inst::create(value *ptr, const std::string &name, instruction *next) {
return new load_inst(ptr, 0, name, next);
unmasked_load_inst* unmasked_load_inst::create(value *ptr, const std::string &name, instruction *next) {
return new unmasked_load_inst(ptr, name, next);
}
// masked load
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value,
const std::string &name, instruction *next)
: load_inst(ptr, 2, name, next) {
: load_inst(ptr, INST_MASKED_LOAD, 3, name, next) {
set_operand(0, ptr);
set_operand(1, mask);
set_operand(2, false_value);
}
@@ -450,23 +464,29 @@ masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false
}
// store
store_inst::store_inst(value *ptr, value *val, unsigned num_extra_ops,
const std::string &name, instruction *next)
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), 2 + num_extra_ops, 1, name, next) {
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
{ }
// unmasked_store
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val,
const std::string &name, instruction *next)
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) {
set_operand(0, ptr);
set_operand(1, val);
}
store_inst* store_inst::create(value *ptr, value *val,
const std::string &name, instruction *next) {
return new store_inst(ptr, val, 0, name, next);
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val,
const std::string &name, instruction *next) {
return new unmasked_store_inst(ptr, val, name, next);
}
// masked store
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
const std::string &name, instruction *next)
: store_inst(ptr, val, 1, name, next) {
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) {
set_operand(0, ptr);
set_operand(1, val);
set_operand(2, mask);
}
@@ -477,15 +497,16 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask
// retile_inst classes
//===----------------------------------------------------------------------===//
retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes,
retile_inst::retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes,
const std::string &name, instruction *next)
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { }
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
// reshape
instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes,
const std::string &name, instruction *next) {
return new reshape_inst(arg, shapes, name, next);
return new reshape_inst(arg, INST_RESHAPE, shapes, name, next);
}
@@ -493,20 +514,20 @@ instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes,
instruction* splat_inst::create(value *arg, const type::tile_shapes_t &shapes,
const std::string &name, instruction *next) {
return new splat_inst(arg, shapes, name, next);
return new splat_inst(arg, INST_SPLAT, shapes, name, next);
}
// broadcast
instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shapes,
const std::string &name, instruction *next) {
return new broadcast_inst(arg, shapes, name, next);
return new broadcast_inst(arg, INST_BROADCAST, shapes, name, next);
}
// downcast
instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) {
return new downcast_inst(arg->get_type()->get_scalar_ty(), arg, name, next);
return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
}
//===----------------------------------------------------------------------===//
@@ -515,7 +536,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
const std::string &name, instruction *next)
: builtin_inst(C->get_type(), 3, 1, name, next), AT_(AT), BT_(BT) {
: builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT) {
set_operand(0, A);
set_operand(1, B);
set_operand(2, C);
@@ -578,7 +599,7 @@ std::vector<constant_int*> trans_inst::init_perm(ir::type* ty, const std::vector
}
trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next)
: builtin_inst(get_res_ty(arg->get_type(), perm), 1, 1, name, next) {
: builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) {
// sanity check
perm_ = init_perm(arg->get_type(), perm);
//auto size = arg->get_type()->get_tile_shapes().size();
@@ -599,7 +620,7 @@ const std::vector<constant_int*> trans_inst::get_perm() const {
//===----------------------------------------------------------------------===//
sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next)
: builtin_inst(arg->get_type(), 1, 1, name, next){
: builtin_inst(arg->get_type(), INST_SQRT, 1, name, next){
set_operand(0, arg);
}
@@ -621,7 +642,7 @@ type* reduce_inst::get_res_type(value *arg, unsigned axis) {
}
reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
: builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
: builtin_inst(get_res_type(arg, axis), INST_REDUCE, 1, name, next),
axis_(axis){
set_operand(0, arg);
}
@@ -636,7 +657,7 @@ instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &n
//===----------------------------------------------------------------------===//
select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next)
: builtin_inst(if_value->get_type(), 3, 1, name, next){
: builtin_inst(if_value->get_type(), INST_SELECT, 3, name, next){
set_operand(0, pred);
set_operand(1, if_value);
set_operand(2, else_value);
@@ -652,7 +673,7 @@ instruction* select_inst::create(value *pred, value *if_value, value *else_value
// get_program_id
get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
: builtin_inst(ty, 0, 1, name, next), axis_(axis){
: builtin_inst(ty, INST_GET_PROGRAM_ID, 0, name, next), axis_(axis){
}
@@ -662,7 +683,7 @@ instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std:
// get_num_program
get_num_program_inst::get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
: builtin_inst(ty, 0, 1, name, next), axis_(axis){
: builtin_inst(ty, INST_GET_NUM_PROGRAMS, 0, name, next), axis_(axis){
}
@@ -674,7 +695,7 @@ instruction* get_num_program_inst::create(context &ctx, unsigned axis, const std
// atomic cas
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), 3, 1, name, next) {
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) {
set_operand(0, ptr);
set_operand(1, cmp);
set_operand(2, val);
@@ -687,7 +708,7 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
// atomic exch
atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next)
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), 2, 1, name, next) {
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) {
set_operand(0, ptr);
set_operand(1, val);
}
@@ -699,7 +720,7 @@ instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string
// atomic add
atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next)
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), 2, 1, name, next) {
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 2, name, next) {
set_operand(0, ptr);
set_operand(1, val);
}
@@ -714,18 +735,13 @@ instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &
// copy to shared
copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
instruction *next) {
return new copy_to_shared_inst(arg->get_type(), arg, name, next);
}
// vectorize
vectorize_inst* vectorize_inst::create(value *arg, const std::string &name, instruction *next) {
return new vectorize_inst(arg->get_type(), arg, name, next);
return new copy_to_shared_inst(arg->get_type(), INST_COPY_TO_SHARED, arg, name, next);
}
// barrier
barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next)
: instruction(type::get_void_ty(ctx), 0, 0, name, next) { }
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
return new barrier_inst(ctx, name, next);
@@ -734,7 +750,7 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
// nv_dynamic_program_idx
make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
: instruction(ty, 0, 1, name, next) { }
: instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
return new make_range_dyn(ty, name, next);
@@ -757,7 +773,7 @@ make_range_sta* make_range_sta::get(make_range* range) {
// make_range
make_range::make_range(type *ty, constant_int *first, constant_int *last)
: instruction(ty, 0), first_(first), last_(last){ }
: instruction(ty, INST_MAKE_RANGE, 0), first_(first), last_(last){ }
make_range *make_range::create(constant_int *first, constant_int *last) {
assert(first->get_type()->is_integer_ty());

54
lib/ir/utils.cc Normal file
View File

@@ -0,0 +1,54 @@
#include <stack>
#include <iostream>
#include "triton/ir/utils.h"
#include "triton/ir/basic_block.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
namespace triton{
namespace ir{
// Returns the basic blocks of `fn` that are reachable from its entry blocks
// (blocks without predecessors), each listed exactly once, in the order an
// iterative depth-first traversal first pops them.
// NOTE(review): a block is emitted when popped, i.e. before its successors,
// so this is a DFS pre-order rather than a strict reverse post-order --
// confirm that callers do not depend on the stricter ordering.
std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
  std::stack<basic_block*> stack;
  std::set<basic_block*> visited;
  std::vector<basic_block*> result;
  // initialize stack with every block that has no predecessor
  for(ir::basic_block* block: fn->blocks())
    if(block->get_predecessors().empty())
      stack.push(block);
  // DFS
  while(!stack.empty()) {
    basic_block* current = stack.top();
    stack.pop();
    // A block can be pushed by several predecessors before it is popped for
    // the first time; only the first pop may emit it, otherwise it would
    // appear more than once in the result.
    if(!visited.insert(current).second)
      continue;
    result.push_back(current);
    for(basic_block* succ: current->get_successors())
      if(visited.find(succ) == visited.end())
        stack.push(succ);
  }
  // return the local by name so that copy elision / implicit move applies
  return result;
}
// Calls `do_work` on every instruction of `mod`: functions are taken in the
// order of mod.get_function_list(), blocks in the order produced by
// cfg::reverse_post_order, and instructions in the order of each block's
// instruction list.
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
  for(ir::function *func: mod.get_function_list()) {
    for(ir::basic_block *bb: cfg::reverse_post_order(func)) {
      for(ir::instruction *inst: bb->get_inst_list()) {
        do_work(inst);
      }
    }
  }
}
// Calls `do_work` once for every distinct value found while walking the
// instructions of `mod` (same traversal order as for_each_instruction).
// For each instruction its operands are offered first, then the instruction
// itself; a value that was already seen is skipped.
void for_each_value(module &mod, const std::function<void (value *)> &do_work) {
  std::set<ir::value*> seen;
  // runs do_work on `v` only the first time `v` is encountered
  auto visit_once = [&seen, &do_work](ir::value *v) {
    if(seen.insert(v).second)
      do_work(v);
  };
  for(ir::function *func: mod.get_function_list()) {
    for(ir::basic_block *bb: cfg::reverse_post_order(func)) {
      for(ir::instruction *inst: bb->get_inst_list()) {
        for(ir::value *operand: inst->ops())
          visit_once(operand);
        visit_once(inst);
      }
    }
  }
}
}
}