[ir][instruction] added identifier for each instruction
This commit is contained in:
@@ -55,6 +55,8 @@ private:
|
||||
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
|
||||
std::vector<unsigned> populate_starting_multiple(ir::value *v);
|
||||
// populate all maps
|
||||
void populate(ir::value *v);
|
||||
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
|
@@ -23,15 +23,24 @@ class axes {
|
||||
|
||||
private:
|
||||
void add_constraint(node_t x, node_t y);
|
||||
void init_c_phi(ir::instruction *i);
|
||||
void init_c_graph(ir::instruction *v);
|
||||
// update graph
|
||||
void update_graph_store(ir::instruction *i);
|
||||
void update_graph_reduce(ir::instruction *i);
|
||||
void update_graph_reshape(ir::instruction *i);
|
||||
void update_graph_splat(ir::instruction *i);
|
||||
void update_graph_trans(ir::instruction *i);
|
||||
void update_graph_broadcast(ir::instruction *i);
|
||||
void update_graph_dot(ir::instruction *i);
|
||||
void update_graph_elementwise(ir::instruction *i);
|
||||
void update_graph(ir::instruction *i);
|
||||
// connected components
|
||||
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
|
||||
public:
|
||||
axes();
|
||||
void run(ir::module &mod);
|
||||
unsigned get(ir::value *value, unsigned ax);
|
||||
bool has(ir::value *value, unsigned ax);
|
||||
unsigned get_id(ir::value *value, unsigned ax);
|
||||
bool has_id(ir::value *value, unsigned ax);
|
||||
|
||||
private:
|
||||
// constraints graph
|
||||
|
@@ -24,8 +24,9 @@ class layout {
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
||||
private:
|
||||
// create edge
|
||||
// graph creation
|
||||
void connect(ir::value *x, ir::value *y);
|
||||
void make_graph(ir::instruction *i);
|
||||
// connected components
|
||||
void connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned id);
|
||||
// list the axes of the given value
|
||||
|
@@ -1,31 +0,0 @@
|
||||
#ifndef TDL_INCLUDE_CODEGEN_VECTORIZE_H
|
||||
#define TDL_INCLUDE_CODEGEN_VECTORIZE_H
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
namespace analysis{
|
||||
class tiles;
|
||||
}
|
||||
|
||||
namespace transform{
|
||||
|
||||
class vectorize {
|
||||
public:
|
||||
vectorize(analysis::tiles *params): params_(params){}
|
||||
void run(ir::module &mod);
|
||||
|
||||
private:
|
||||
analysis::tiles *params_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -140,7 +140,6 @@ public:
|
||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||
// Intrinsics
|
||||
value *create_copy_to_shared(value *arg, const std::string &name = "");
|
||||
value *create_vectorize(value *arg, const std::string &name = "");
|
||||
value *create_barrier(const std::string &name = "");
|
||||
|
||||
private:
|
||||
|
@@ -77,6 +77,67 @@ enum cmp_pred_t {
|
||||
LAST_ICMP_PREDICATE
|
||||
};
|
||||
|
||||
// Identifier for each kind of IR instruction. An `instruction` receives one of
// these values through its constructor and exposes it through
// `instruction::get_id()`. Enumerators carry no explicit initializers, so they
// are numbered consecutively starting from INST_BEGIN == 0.
enum value_id_t: unsigned {
|
||||
/* ------------ *
|
||||
INSTRUCTIONS
|
||||
* ------------ */
|
||||
// first enumerator of the instruction range
INST_BEGIN,
|
||||
// phi
|
||||
INST_PHI,
|
||||
// arithmetic
|
||||
INST_BINOP,
|
||||
INST_GETELEMENTPTR,
|
||||
INST_SELECT,
|
||||
INST_SQRT,
|
||||
// cmp
|
||||
INST_ICMP,
|
||||
INST_FCMP,
|
||||
// cast
|
||||
INST_CAST_TRUNC,
|
||||
INST_CAST_ZEXT,
|
||||
INST_CAST_SEXT,
|
||||
INST_CAST_FP_TRUNC,
|
||||
INST_CAST_FP_EXT,
|
||||
INST_CAST_UI_TO_FP,
|
||||
INST_CAST_SI_TO_FP,
|
||||
INST_CAST_FP_TO_UI,
|
||||
INST_CAST_FP_TO_SI,
|
||||
INST_CAST_PTR_TO_INT,
|
||||
INST_CAST_INT_TO_PTR,
|
||||
INST_CAST_BIT_CAST,
|
||||
INST_CAST_ADDR_SPACE_CAST,
|
||||
// terminators
|
||||
INST_RETURN,
|
||||
INST_COND_BRANCH,
|
||||
INST_UNCOND_BRANCH,
|
||||
// io
|
||||
INST_UNMASKED_LOAD,
|
||||
INST_MASKED_LOAD,
|
||||
INST_UNMASKED_STORE,
|
||||
INST_MASKED_STORE,
|
||||
// retile
|
||||
INST_RESHAPE,
|
||||
INST_SPLAT,
|
||||
INST_BROADCAST,
|
||||
INST_DOWNCAST,
|
||||
// builtin
|
||||
INST_GET_PROGRAM_ID,
|
||||
INST_GET_NUM_PROGRAMS,
|
||||
// atomics
|
||||
INST_ATOMIC_CAS,
|
||||
INST_ATOMIC_EXCH,
|
||||
INST_ATOMIC_ADD,
|
||||
// array arithmetic
|
||||
INST_TRANS,
|
||||
INST_REDUCE,
|
||||
INST_DOT,
|
||||
// intrinsics
|
||||
INST_COPY_TO_SHARED,
|
||||
INST_BARRIER,
|
||||
INST_MAKE_RANGE_DYN,
|
||||
INST_MAKE_RANGE_STA,
|
||||
INST_MAKE_RANGE
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
@@ -40,7 +40,8 @@ private:
|
||||
|
||||
protected:
|
||||
// constructors
|
||||
instruction(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
|
||||
instruction(type *ty, value_id_t ity, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
// parent
|
||||
@@ -59,32 +60,21 @@ public:
|
||||
// cloning
|
||||
// Returns the instruction produced by clone_impl(). Before returning, the
// result is named, add_use(result) is called on every operand of *this*
// instruction, and the result's parent block pointer is reset to nullptr.
ir::instruction* clone() {
|
||||
ir::instruction* res = clone_impl();
|
||||
// for(auto it = op_begin(); it != op_end(); it++){
|
||||
// (*it)->add_use(res);
|
||||
// }
|
||||
// NOTE(review): the hard-coded name looks like a debugging leftover - confirm
res->set_name("testcloned");
|
||||
// operands are iterated on the original; each one is told about the clone
for(auto it = op_begin(); it != op_end(); it++)
|
||||
(*it)->add_use(res);
|
||||
// the clone is not attached to any basic block
res->parent_ = nullptr;
|
||||
return res;
|
||||
}
|
||||
// instruction id
|
||||
value_id_t get_id() const { return id_; }
|
||||
|
||||
private:
|
||||
basic_block *parent_;
|
||||
std::map<ir::metadata::kind_t, unsigned> metadatas_;
|
||||
value_id_t id_;
|
||||
};
|
||||
|
||||
|
||||
// result reference
|
||||
class result_reference: public value {
|
||||
public:
|
||||
result_reference(instruction *ref, unsigned arg_id, const std::string &name = "");
|
||||
instruction *get_ref();
|
||||
unsigned get_arg_id();
|
||||
|
||||
private:
|
||||
instruction *ref_;
|
||||
unsigned arg_id_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// phi_node classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -173,11 +163,13 @@ public:
|
||||
class cmp_inst: public instruction{
|
||||
public:
|
||||
typedef cmp_pred_t pred_t;
|
||||
|
||||
private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
cmp_inst(type *ty, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
cmp_inst(type *ty, value_id_t id, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
static bool is_fp_predicate(cmp_pred_t pred);
|
||||
static bool is_int_predicate(cmp_pred_t pred);
|
||||
static type* make_cmp_result_type(type *ty);
|
||||
@@ -190,7 +182,8 @@ private:
|
||||
};
|
||||
|
||||
class icmp_inst: public cmp_inst {
|
||||
using cmp_inst::cmp_inst;
|
||||
icmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
@@ -199,7 +192,8 @@ public:
|
||||
};
|
||||
|
||||
class fcmp_inst: public cmp_inst {
|
||||
using cmp_inst::cmp_inst;
|
||||
fcmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
@@ -213,7 +207,7 @@ public:
|
||||
|
||||
class unary_inst: public instruction {
|
||||
protected:
|
||||
unary_inst(type *Ty, value *v, const std::string &name, instruction *next);
|
||||
unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
|
||||
};
|
||||
|
||||
|
||||
@@ -226,8 +220,8 @@ private:
|
||||
std::string repr_impl() const;
|
||||
|
||||
protected:
|
||||
cast_inst(type *ty, value *v, const std::string &name, instruction *next, cast_op_t op)
|
||||
: unary_inst(ty, v, name, next), op_(op) { }
|
||||
cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op)
|
||||
: unary_inst(ty, id, v, name, next), op_(op) { }
|
||||
|
||||
private:
|
||||
static bool is_valid(cast_op_t op, value *arg, type *ty);
|
||||
@@ -246,27 +240,27 @@ private:
|
||||
cast_op_t op_;
|
||||
};
|
||||
|
||||
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \
|
||||
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \
|
||||
class name : public cast_inst { \
|
||||
_TRITON_DEFINE_CLONE(name); \
|
||||
friend class cast_inst; \
|
||||
name(type *ty, value *v, const std::string &name, instruction *next) \
|
||||
: cast_inst(ty, v, name, next, op){ } \
|
||||
: cast_inst(ty, id, v, name, next, op){ } \
|
||||
};
|
||||
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, cast_op_t::Trunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, cast_op_t::ZExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, cast_op_t::SExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, cast_op_t::FPTrunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, cast_op_t::FPExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, cast_op_t::UIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, cast_op_t::SIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, cast_op_t::FPToUI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, cast_op_t::FPToSI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, cast_op_t::PtrToInt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, cast_op_t::IntToPtr)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, cast_op_t::BitCast)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, cast_op_t::AddrSpaceCast)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast)
|
||||
TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// terminator_inst classes
|
||||
@@ -372,33 +366,38 @@ private:
|
||||
|
||||
class io_inst: public instruction {
|
||||
protected:
|
||||
io_inst(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
|
||||
io_inst(type *ty, value_id_t id, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
|
||||
// value *get_mask() const;
|
||||
// value *get_false_value() const;
|
||||
};
|
||||
|
||||
// load
|
||||
class load_inst: public io_inst {
|
||||
protected:
|
||||
load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next);
|
||||
load_inst(value *ptr, value_id_t id, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
private:
|
||||
std::string repr_impl() const { return "load"; }
|
||||
static type *get_pointee_type(type *ty);
|
||||
|
||||
public:
|
||||
|
||||
// factory method
|
||||
static load_inst* create(value *ptr,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(load_inst)
|
||||
};
|
||||
|
||||
// unmasked load
|
||||
// Load instruction variant derived from load_inst. Instances are obtained
// through the static create() factory; the constructor is private.
class unmasked_load_inst: public load_inst {
|
||||
private:
|
||||
// textual name of this instruction kind
std::string repr_impl() const { return "unmasked_load"; }
|
||||
unmasked_load_inst(value *ptr, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// factory; `name` defaults to "" and `next` to nullptr
static unmasked_load_inst* create(value *ptr,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_load_inst)
|
||||
};
|
||||
|
||||
// masked load
|
||||
class masked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load"; }
|
||||
@@ -416,22 +415,28 @@ public:
|
||||
_TRITON_DEFINE_CLONE(masked_load_inst)
|
||||
};
|
||||
|
||||
class store_inst: public io_inst{
|
||||
// store
|
||||
class store_inst: public io_inst {
|
||||
protected:
|
||||
store_inst(value *ptr, value *v, unsigned num_extra_ops,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
private:
|
||||
std::string repr_impl() const { return "store"; }
|
||||
store_inst(value *ptr, value_id_t id, unsigned num_ops,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_value_operand() { return get_operand(1); }
|
||||
};
|
||||
|
||||
// unmasked_store
|
||||
class unmasked_store_inst: public store_inst{
|
||||
private:
|
||||
std::string repr_impl() const { return "unmasked_store"; }
|
||||
unmasked_store_inst(value *ptr, value *v, const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
// factory method
|
||||
static store_inst* create(value* ptr, value *v,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(store_inst)
|
||||
static unmasked_store_inst* create(value* ptr, value *v,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(unmasked_store_inst)
|
||||
};
|
||||
|
||||
class masked_store_inst: public store_inst{
|
||||
@@ -458,7 +463,7 @@ public:
|
||||
|
||||
class retile_inst: public unary_inst {
|
||||
protected:
|
||||
retile_inst(value *arg, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
|
||||
retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
|
||||
};
|
||||
|
||||
// reshape
|
||||
@@ -690,16 +695,6 @@ public:
|
||||
instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class vectorize_inst: public unary_inst{
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "vectorize"; }
|
||||
_TRITON_DEFINE_CLONE(vectorize_inst)
|
||||
|
||||
public:
|
||||
static vectorize_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
// On NVIDIA, implementation is such that
|
||||
// constant_range = nv_dynamic_program_idx + nv_static_program_idx
|
||||
// so as to enable re-association on nv_static_program_idx which is constant
|
||||
|
@@ -4,18 +4,25 @@
|
||||
#define _TRITON_IR_CFG_H_
|
||||
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
class module;
|
||||
class function;
|
||||
class basic_block;
|
||||
class instruction;
|
||||
class value;
|
||||
|
||||
// Static helpers operating on the control-flow graph of a function.
class cfg {
|
||||
public:
|
||||
// Returns basic blocks of `fn` as a vector; judging by the name the order is
// reverse post-order (implementation not visible here - confirm).
static std::vector<basic_block *> reverse_post_order(function* fn);
|
||||
};
|
||||
|
||||
void for_each_instruction(ir::module& mod, const std::function<void(triton::ir::instruction*)> &fn);
|
||||
void for_each_value(ir::module& mod, const std::function<void(triton::ir::value *)> &fn);
|
||||
|
||||
}
|
||||
}
|
||||
|
@@ -20,7 +20,6 @@
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/membar.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/transform/vectorize.h"
|
||||
#include "triton/lang/parser.h"
|
||||
#include "triton/runtime/arg.h"
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
@@ -14,15 +15,19 @@ namespace analysis{
|
||||
|
||||
|
||||
// Greatest common divisor of a and b (iterative Euclidean algorithm).
//
// gcd(a, 0) == a, gcd(0, b) == b and gcd(0, 0) == 0. For non-negative
// operands the result is the usual non-negative GCD. For negative operands
// the sign of the result follows the sign rules of the `%` operator.
//
// The loop runs in a logarithmic number of steps and uses no recursion, so
// it cannot exhaust the stack on operands of very different magnitude.
inline int gcd(int a, int b) {
  while (b != 0) {
    const int remainder = a % b;
    a = b;
    b = remainder;
  }
  return a;
}
|
||||
|
||||
inline int lcm(int a, int b) {
|
||||
return (a * b) / gcd(a, b);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@@ -64,8 +69,8 @@ std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<cst_info> result;
|
||||
ir::value* op = x->get_operand(0);
|
||||
std::vector<cst_info> result;
|
||||
auto op_cst = populate_is_constant(op);
|
||||
for(auto d: shapes)
|
||||
result.push_back(cst_info{d, op_cst[0].value});
|
||||
@@ -478,28 +483,15 @@ std::vector<unsigned> align::contiguous(ir::value* v) const {
|
||||
return max_contiguous_.at(v);
|
||||
}
|
||||
|
||||
|
||||
// Populates all maps for `v` by calling, in this order,
// populate_is_constant, populate_starting_multiple and
// populate_max_contiguous. Their return values are discarded here.
void align::populate(ir::value *v) {
|
||||
populate_is_constant(v);
|
||||
populate_starting_multiple(v);
|
||||
populate_max_contiguous(v);
|
||||
}
|
||||
|
||||
// Runs the alignment analysis on `mod`.
//
// Every value reported by ir::for_each_value is handed to populate(), which
// performs the constant, starting-multiple and max-contiguous passes for that
// value. The separate per-instruction sweeps that used to precede this call
// invoked the same three populate_* functions and are therefore dropped.
// NOTE(review): assumes ir::for_each_value reaches every instruction the
// removed loops visited - confirm against triton/ir/utils.
void align::run(ir::module &mod) {
  ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
}
|
||||
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
@@ -30,91 +31,113 @@ void axes::add_constraint(node_t x, node_t y) {
|
||||
nodes_.insert(y);
|
||||
}
|
||||
|
||||
void axes::init_c_graph(ir::instruction *v) {
|
||||
// Reference shape
|
||||
ir::type::tile_shapes_t shapes;
|
||||
if(auto *store = dynamic_cast<ir::store_inst*>(v))
|
||||
shapes = store->get_pointer_operand()->get_type()->get_tile_shapes();
|
||||
else if(auto *atom = dynamic_cast<ir::atomic_add_inst*>(v))
|
||||
shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
|
||||
else if(dynamic_cast<ir::downcast_inst*>(v))
|
||||
|
||||
// Reduction: each input axis other than the reduced one is tied to the next
// free axis of the result, in increasing order.
void axes::update_graph_reduce(ir::instruction *i) {
  auto *reduce = static_cast<ir::reduce_inst*>(i);
  unsigned reduced_axis = reduce->get_axis();
  ir::value *input = reduce->get_operand(0);
  auto input_shapes = input->get_type()->get_tile_shapes();
  // index of the result axis that receives the next kept input axis
  unsigned out_axis = 0;
  for(unsigned in_axis = 0; in_axis < input_shapes.size(); in_axis++){
    if(in_axis != reduced_axis){
      add_constraint({i, out_axis}, {input, in_axis});
      out_axis++;
    }
  }
}
|
||||
|
||||
// Reshape: result axes are matched against operand axes from left to right.
// While the shapes agree, result axis d is tied to the current operand axis.
// Once the reshape is "skewed" (a result dimension larger than 1 differs
// from the operand dimension), or the operand axes are exhausted, each
// remaining result axis only gets a self-edge so that it still appears as a
// node in the graph.
void axes::update_graph_reshape(ir::instruction *i) {
  auto* reshape = static_cast<ir::reshape_inst*>(i);
  // operands
  ir::value *op = reshape->get_operand(0);
  // shapes
  auto op_shapes = op->get_type()->get_tile_shapes();
  auto res_shapes = reshape->get_type()->get_tile_shapes();
  // construct edges
  unsigned current = 0;
  bool is_skewed = false;
  for(unsigned d = 0; d < res_shapes.size(); d ++){
    // `current` can reach the operand rank (e.g. when the result has extra
    // trailing axes); never index op_shapes past its end
    bool same_shape = current < op_shapes.size() &&
                      res_shapes[d] == op_shapes[current];
    // either add edge between axis or just add a node in the graph
    if(!is_skewed && same_shape)
      add_constraint({i, d}, {op, current++});
    else
      add_constraint({i, d}, {i, d});
    // reshaping is skewed
    if(res_shapes[d] > 1 && !same_shape)
      is_skewed = true;
  }
}
|
||||
|
||||
// Splat: intentionally a no-op.
void axes::update_graph_splat(ir::instruction *) {
  // the argument is a scalar, so there is no edge to create
}
|
||||
|
||||
// Transposition: result axis perm[d] is tied to operand axis d.
void axes::update_graph_trans(ir::instruction *i) {
  auto *transpose = static_cast<ir::trans_inst*>(i);
  ir::value *source = transpose->get_operand(0);
  auto perm = transpose->get_perm();
  for(unsigned src_axis = 0; src_axis < perm.size(); src_axis++){
    // axis of the result that the permutation assigns to src_axis
    auto dst_axis = perm[src_axis]->get_value();
    add_constraint({i, dst_axis}, {source, src_axis});
  }
}
|
||||
|
||||
// Broadcast: an axis is tied to the matching operand axis only when its size
// is the same in the operand and in the result (i.e. it was not broadcast).
void axes::update_graph_broadcast(ir::instruction *i) {
  auto *bcast = static_cast<ir::broadcast_inst*>(i);
  auto result_shapes = bcast->get_type()->get_tile_shapes();
  ir::value *source = bcast->get_operand(0);
  ir::type *source_ty = source->get_type();
  const auto& source_shapes = source_ty->get_tile_shapes();
  for(unsigned d = 0; d < result_shapes.size(); d++){
    // skip axes whose extent changed
    if(source_shapes[d] != result_shapes[d])
      continue;
    add_constraint({i, d}, {source, d});
  }
}
|
||||
|
||||
// Dot: every result axis is tied to the same axis of the third operand (the
// accumulator); axes from index 2 upwards (batch dimensions) are additionally
// tied to the same axis of the first and second operands.
void axes::update_graph_dot(ir::instruction *i) {
  auto *dot = static_cast<ir::dot_inst*>(i);
  auto result_shapes = dot->get_type()->get_tile_shapes();
  ir::value *lhs = dot->get_operand(0);
  ir::value *rhs = dot->get_operand(1);
  ir::value *acc = dot->get_operand(2);
  const auto rank = result_shapes.size();
  // result <-> accumulator
  for(unsigned d = 0; d < rank; d++){
    add_constraint({dot, d}, {acc, d});
  }
  // batch dimensions
  for(unsigned d = 2; d < rank; d++){
    add_constraint({dot, d}, {lhs, d});
    add_constraint({dot, d}, {rhs, d});
  }
}
|
||||
|
||||
void axes::update_graph_elementwise(ir::instruction *i) {
|
||||
if(i->get_num_operands() == 0)
|
||||
return;
|
||||
else if(dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
return;
|
||||
else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
|
||||
unsigned axis = reduce->get_axis();
|
||||
ir::value *arg = reduce->get_operand(0);
|
||||
auto in_shapes = arg->get_type()->get_tile_shapes();
|
||||
unsigned current = 0;
|
||||
for(unsigned i = 0; i < in_shapes.size(); i++){
|
||||
if(i == axis)
|
||||
continue;
|
||||
add_constraint({reduce, current++}, {arg, i});
|
||||
}
|
||||
ir::value *op = i->get_operand(0);
|
||||
if(!op->get_type()->is_tile_ty())
|
||||
return;
|
||||
auto rank = op->get_type()->get_tile_rank();
|
||||
for(unsigned d = 0; d < rank; d++)
|
||||
for(ir::value* opx: i->ops())
|
||||
for(ir::value* opy: i->ops()){
|
||||
if(!i->get_type()->is_void_ty())
|
||||
add_constraint({i, d}, {opx, d});
|
||||
add_constraint({opx, d}, {opy, d});
|
||||
}
|
||||
else
|
||||
shapes = v->get_type()->get_tile_shapes();
|
||||
// Reshape
|
||||
if(dynamic_cast<ir::reshape_inst*>(v)) {
|
||||
ir::value *op = v->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
if(shapes[i] == 1){
|
||||
add_constraint({v, i}, {v, i});
|
||||
}
|
||||
else if(!is_skewed &&
|
||||
shapes[i] == op_shapes[current])
|
||||
add_constraint({v, i}, {op, current++});
|
||||
else{
|
||||
is_skewed = true;
|
||||
add_constraint({v, i}, {v, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
// Splat
|
||||
else if(dynamic_cast<ir::splat_inst*>(v)){
|
||||
return;
|
||||
}
|
||||
// Trans
|
||||
else if(auto *x = dynamic_cast<ir::trans_inst*>(v)){
|
||||
ir::value *op = v->get_operand(0);
|
||||
auto perm = x->get_perm();
|
||||
for(unsigned i = 0; i < perm.size(); i++)
|
||||
add_constraint({v, perm[i]->get_value()}, {op, i});
|
||||
}
|
||||
// Broadcast
|
||||
else if(dynamic_cast<ir::broadcast_inst*>(v)){
|
||||
ir::value *op = v->get_operand(0);
|
||||
ir::type *op_ty = op->get_type();
|
||||
const auto& op_shapes = op_ty->get_tile_shapes();
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
if(op_shapes[i] == shapes[i] && v != op)
|
||||
add_constraint({v, i}, {op, i});
|
||||
}
|
||||
}
|
||||
// Matrix multiplication
|
||||
else if(dynamic_cast<ir::dot_inst*>(v)){
|
||||
ir::value *A = v->get_operand(0);
|
||||
ir::value *B = v->get_operand(1);
|
||||
ir::value *D = v->get_operand(2);
|
||||
for(unsigned i = 0; i < shapes.size(); i++)
|
||||
add_constraint({v, i}, {D, i});
|
||||
for(unsigned i = 2; i < shapes.size(); i++){
|
||||
add_constraint({v, i}, {A, i});
|
||||
add_constraint({v, i}, {B, i});
|
||||
}
|
||||
}
|
||||
// Element-wise
|
||||
else if(dynamic_cast<ir::user*>(v)) {
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
std::vector<ir::value*> ops = v->ops();
|
||||
for(ir::value* op: ops)
|
||||
add_constraint({v, i}, {op, i});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Dispatches on the instruction id to the matching update_graph_* routine;
// every id without a dedicated routine is handled as element-wise.
void axes::update_graph(ir::instruction *i) {
  switch (i->get_id()) {
    case ir::INST_REDUCE:
      update_graph_reduce(i);
      break;
    case ir::INST_RESHAPE:
      update_graph_reshape(i);
      break;
    case ir::INST_SPLAT:
      update_graph_splat(i);
      break;
    case ir::INST_TRANS:
      update_graph_trans(i);
      break;
    case ir::INST_BROADCAST:
      update_graph_broadcast(i);
      break;
    case ir::INST_DOT:
      update_graph_dot(i);
      break;
    default:
      update_graph_elementwise(i);
      break;
  }
}
|
||||
|
||||
void axes::connected_components(node_t x, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
@@ -126,12 +149,12 @@ void axes::connected_components(node_t x, std::set<node_t> &nodes, graph_t &grap
|
||||
}
|
||||
}
|
||||
|
||||
unsigned axes::get(ir::value *value, unsigned ax) {
|
||||
unsigned axes::get_id(ir::value *value, unsigned ax) {
|
||||
unsigned result = groups_.at(value).at(ax);
|
||||
return result;
|
||||
}
|
||||
|
||||
bool axes::has(ir::value *value, unsigned ax) {
|
||||
bool axes::has_id(ir::value *value, unsigned ax) {
|
||||
auto it = groups_.find(value);
|
||||
if(it == groups_.end())
|
||||
return false;
|
||||
@@ -146,15 +169,9 @@ void axes::run(ir::module &mod) {
|
||||
nodes_.clear();
|
||||
dependencies_.clear();
|
||||
groups_.clear();
|
||||
// Create graph
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
// Build constraints graph
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list())
|
||||
if(i->has_tile_result_or_op())
|
||||
init_c_graph(i);
|
||||
}
|
||||
// Axes
|
||||
// make graph
|
||||
ir::for_each_instruction(mod, [this](ir::instruction *x) { update_graph(x); });
|
||||
// connected components
|
||||
unsigned group_id = 0;
|
||||
while(!nodes_.empty())
|
||||
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
|
||||
|
@@ -4,6 +4,7 @@
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
@@ -20,8 +21,8 @@ std::set<int> layout::axes_of(ir::value *value) {
|
||||
// create result
|
||||
std::set<int> result;
|
||||
for(size_t d = 0; d < rank; d++){
|
||||
if(axes_->has(value, d))
|
||||
result.insert(axes_->get(value, d));
|
||||
if(axes_->has_id(value, d))
|
||||
result.insert(axes_->get_id(value, d));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -74,24 +75,23 @@ void layout::connect(ir::value *x, ir::value *y) {
|
||||
}
|
||||
}
|
||||
|
||||
// For every ordered pair (lhs, rhs) of operands of `i`, connects `i` to lhs
// and lhs to rhs.
void layout::make_graph(ir::instruction *i) {
  for(ir::value *lhs : i->ops()){
    for(ir::value *rhs : i->ops()){
      connect(i, lhs);
      connect(lhs, rhs);
    }
  }
}
|
||||
|
||||
// run
|
||||
void layout::run(ir::module &mod) {
|
||||
nodes_.clear();
|
||||
dependencies_.clear();
|
||||
groups_.clear();
|
||||
values_.clear();
|
||||
// Create graph
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i : block->get_inst_list()) {
|
||||
for(ir::value* opx: i->ops())
|
||||
for(ir::value* opy: i->ops()){
|
||||
connect(i, opx);
|
||||
connect(opx, opy);
|
||||
}
|
||||
|
||||
}
|
||||
// Grids
|
||||
// make graph
|
||||
ir::for_each_instruction(mod, [this](ir::instruction* i) { make_graph(i); });
|
||||
// connected components
|
||||
unsigned group_id = 0;
|
||||
while(!nodes_.empty()){
|
||||
connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++);
|
||||
|
@@ -82,9 +82,9 @@ void add_copy(ir::value *x, ir::builder &builder) {
|
||||
}
|
||||
|
||||
void meminfo::run(ir::module &mod) {
|
||||
// shared_.clear();
|
||||
// refs_.clear();
|
||||
// double_.clear();
|
||||
shared_.clear();
|
||||
refs_.clear();
|
||||
double_.clear();
|
||||
|
||||
// Add shared copies
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
|
@@ -43,19 +43,19 @@ bool tiles::hmma(ir::value *value) {
|
||||
}
|
||||
|
||||
// Returns the mts_ entry stored for the axis id that the axes analysis
// assigns to axis `ax` of `value`. Both lookups use .at(), so an unknown
// value/axis or id throws std::out_of_range.
int tiles::mts(ir::value *value, unsigned ax) {
  return mts_.at(axes_->get_id(value, ax));
}
|
||||
|
||||
// Returns the nts_ entry stored for the axis id that the axes analysis
// assigns to axis `ax` of `value`. Both lookups use .at(), so an unknown
// value/axis or id throws std::out_of_range.
int tiles::nts(ir::value *value, unsigned ax) {
  return nts_.at(axes_->get_id(value, ax));
}
|
||||
|
||||
// Returns the fpw_ entry stored for the axis id that the axes analysis
// assigns to axis `ax` of `value`. Both lookups use .at(), so an unknown
// value/axis or id throws std::out_of_range.
int tiles::fpw(ir::value *value, unsigned ax) {
  return fpw_.at(axes_->get_id(value, ax));
}
|
||||
|
||||
// Returns the wpt_ entry stored for the axis id that the axes analysis
// assigns to axis `ax` of `value`. Both lookups use .at(), so an unknown
// value/axis or id throws std::out_of_range.
int tiles::wpt(ir::value *value, unsigned ax) {
  return wpt_.at(axes_->get_id(value, ax));
}
|
||||
|
||||
std::vector<int> tiles::order(ir::value *v) {
|
||||
@@ -92,7 +92,7 @@ void tiles::init_hmma_tile(ir::value *i) {
|
||||
}while(fpw_nm1 != fpw);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
fpw_[axes_->get(i, d)] = fpw[d];
|
||||
fpw_[axes_->get_id(i, d)] = fpw[d];
|
||||
/* warps per tile */
|
||||
// try to make things as square as possible to maximize data re-use
|
||||
std::vector<unsigned> wpt = {1, 1, 1};
|
||||
@@ -106,11 +106,11 @@ void tiles::init_hmma_tile(ir::value *i) {
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
wpt_[axes_->get(i, d)] = wpt[d];
|
||||
wpt_[axes_->get_id(i, d)] = wpt[d];
|
||||
/* sanity check */
|
||||
unsigned effective_num_warps = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_warps *= wpt_[axes_->get(i, d)];
|
||||
effective_num_warps *= wpt_[axes_->get_id(i, d)];
|
||||
if(num_warps_ != effective_num_warps)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
@@ -122,19 +122,19 @@ void tiles::init_scanline_tile(ir::value *i) {
|
||||
unsigned ld = ord[0];
|
||||
unsigned num_threads = num_warps_*32;
|
||||
unsigned current = num_threads;
|
||||
nts_[axes_->get(i, ld)] = clamp(size / num_threads, 1, 4);
|
||||
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get(i, ld)]);
|
||||
current = current / mts_[axes_->get(i, ld)];
|
||||
nts_[axes_->get_id(i, ld)] = clamp(size / num_threads, 1, 4);
|
||||
mts_[axes_->get_id(i, ld)] = clamp(current, 1, shapes[ld] / nts_[axes_->get_id(i, ld)]);
|
||||
current = current / mts_[axes_->get_id(i, ld)];
|
||||
for(size_t d = 1; d < shapes.size(); d++){
|
||||
ld = ord[d];
|
||||
nts_[axes_->get(i, ld)] = 1;
|
||||
mts_[axes_->get(i, ld)] = clamp(current, 1, shapes[ld]);
|
||||
current = current / mts_[axes_->get(i, ld)];
|
||||
nts_[axes_->get_id(i, ld)] = 1;
|
||||
mts_[axes_->get_id(i, ld)] = clamp(current, 1, shapes[ld]);
|
||||
current = current / mts_[axes_->get_id(i, ld)];
|
||||
}
|
||||
/* sanity check */
|
||||
unsigned effective_num_threads = 1;
|
||||
for(size_t d = 0; d < shapes.size(); d++)
|
||||
effective_num_threads *= mts_[axes_->get(i, d)];
|
||||
effective_num_threads *= mts_[axes_->get_id(i, d)];
|
||||
if(num_threads != effective_num_threads)
|
||||
throw std::runtime_error("cannot create a kernel with this amount of warps");
|
||||
}
|
||||
|
@@ -615,7 +615,7 @@ void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value
|
||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[a_axes_->get(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
|
||||
axes_[a_axes_->get_id(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -720,10 +720,10 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
|
||||
|
||||
|
||||
/* axes */
|
||||
axes_[a_axes_->get(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
|
||||
axes_[a_axes_->get(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
|
||||
axes_[a_axes_->get_id(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
|
||||
axes_[a_axes_->get_id(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
|
||||
if(is_batched)
|
||||
axes_[a_axes_->get(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
|
||||
axes_[a_axes_->get_id(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
|
||||
}
|
||||
|
||||
|
||||
@@ -791,7 +791,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) {
|
||||
std::vector<distributed_axis> axes(shapes.size());
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
unsigned x = a_axes_->get(v, d);
|
||||
unsigned x = a_axes_->get_id(v, d);
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
else{
|
||||
@@ -942,7 +942,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
||||
for(auto& x: partial) {
|
||||
// current element being computed
|
||||
Value *lane = axes_.at(a_axes_->get(op, axis)).thread_id;
|
||||
Value *lane = axes_.at(a_axes_->get_id(op, axis)).thread_id;
|
||||
Value *&result = x.second;
|
||||
indices_t write_idx = x.first;
|
||||
write_idx.insert(write_idx.begin() + axis, lane);
|
||||
|
@@ -103,39 +103,9 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
|
||||
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
|
||||
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
|
||||
};
|
||||
// bool z_order = true;
|
||||
// if(z_order && ax < 2){
|
||||
// static std::array<Intrinsic::ID, 3> n_cta_ids = {
|
||||
// Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
|
||||
// Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
|
||||
// Intrinsic::nvvm_read_ptx_sreg_nctaid_z
|
||||
// };
|
||||
// Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {});
|
||||
// Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {});
|
||||
// Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {});
|
||||
// Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {});
|
||||
// // global block ID
|
||||
// Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0));
|
||||
// // helper for minimum
|
||||
// auto Min = [&](Value *x, Value *y){
|
||||
// return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x);
|
||||
// };
|
||||
// // super-tile size
|
||||
// Value* sts = Min(builder.getInt32(16), n_cta_id_1);
|
||||
// // number of CTAs per super-block
|
||||
// Value *nscta = builder.CreateMul(n_cta_id_0, sts);
|
||||
// Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0);
|
||||
// Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts));
|
||||
// if(ax == 0)
|
||||
// return bid0;
|
||||
// else
|
||||
// return bid1;
|
||||
// }
|
||||
// else{
|
||||
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
|
||||
Value* cta_id = builder.CreateCall(get_cta_id, {});
|
||||
return cta_id;
|
||||
// }
|
||||
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
|
||||
Value* cta_id = builder.CreateCall(get_cta_id, {});
|
||||
return cta_id;
|
||||
}
|
||||
|
||||
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
|
||||
|
@@ -2,7 +2,7 @@
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/module.h"
|
||||
|
@@ -1,7 +1,7 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include <iostream>
|
||||
|
||||
|
@@ -9,7 +9,7 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
|
@@ -6,7 +6,7 @@
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/ir/utils.h"
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
|
@@ -1,41 +0,0 @@
|
||||
#include "triton/codegen/transform/vectorize.h"
|
||||
#include "triton/codegen/analysis/tiles.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace codegen{
|
||||
namespace transform{
|
||||
|
||||
void vectorize::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(auto *trans = dynamic_cast<ir::trans_inst*>(i)){
|
||||
ir::value *x = i->get_operand(0);
|
||||
if(trans->get_perm()[0]->get_value() != 0)
|
||||
continue;
|
||||
builder.set_insert_point(i);
|
||||
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
|
||||
x->replace_all_uses_with(rx);
|
||||
rx->set_operand(0, x);
|
||||
}
|
||||
if(dynamic_cast<ir::copy_to_shared_inst*>(i)){
|
||||
ir::value *x = i->get_operand(0);
|
||||
if(params_->nts(x, 0) == 1)
|
||||
continue;
|
||||
builder.set_insert_point(i);
|
||||
ir::instruction *rx = (ir::instruction*)builder.create_vectorize(x);
|
||||
x->replace_all_uses_with(rx);
|
||||
rx->set_operand(0, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@@ -241,6 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -252,11 +252,11 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_load(value *ptr, const std::string &name){
|
||||
return insert(load_inst::create(ptr, name));
|
||||
return insert(unmasked_load_inst::create(ptr, name));
|
||||
}
|
||||
|
||||
value *builder::create_store(value *ptr, value *val, const std::string &name){
|
||||
return insert(store_inst::create(ptr, val, name));
|
||||
return insert(unmasked_store_inst::create(ptr, val, name));
|
||||
}
|
||||
|
||||
value *builder::create_masked_load(value *ptr, value *mask, value *false_value, const std::string &name){
|
||||
@@ -340,10 +340,6 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) {
|
||||
return insert(copy_to_shared_inst::create(arg, name));
|
||||
}
|
||||
|
||||
value *builder::create_vectorize(value *arg, const std::string &name) {
|
||||
return insert(vectorize_inst::create(arg, name));
|
||||
}
|
||||
|
||||
value *builder::create_barrier(const std::string &name) {
|
||||
return insert(barrier_inst::create(ctx_, name));
|
||||
}
|
||||
|
@@ -1,31 +0,0 @@
|
||||
#include <stack>
|
||||
#include "triton/ir/cfg.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/function.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
|
||||
std::stack<basic_block*> stack;
|
||||
std::set<basic_block*> visited;
|
||||
std::vector<basic_block*> result;
|
||||
// initialize stack
|
||||
for(ir::basic_block* block: fn->blocks())
|
||||
if(block->get_predecessors().empty())
|
||||
stack.push(block);
|
||||
// DFS
|
||||
while(!stack.empty()) {
|
||||
basic_block* current = stack.top();
|
||||
stack.pop();
|
||||
result.push_back(current);
|
||||
visited.insert(current);
|
||||
for(basic_block* succ: current->get_successors())
|
||||
if(visited.find(succ) == visited.end())
|
||||
stack.push(succ);
|
||||
}
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@@ -12,8 +12,9 @@ namespace ir{
|
||||
// instruction classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
instruction::instruction(type *ty, unsigned num_ops, unsigned num_results, const std::string &name, instruction *next)
|
||||
: user(ty, num_ops, name) {
|
||||
instruction::instruction(type *ty, value_id_t ity, unsigned num_ops,
|
||||
const std::string &name, instruction *next)
|
||||
: user(ty, num_ops, name), id_(ity) {
|
||||
if(next){
|
||||
basic_block *block = next->get_parent();
|
||||
assert(block && "Next instruction is not in a basic block!");
|
||||
@@ -35,17 +36,12 @@ bool instruction::has_tile_result_or_op() {
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
// result reference
|
||||
result_reference::result_reference(instruction *ref, unsigned arg_id, const std::string &name)
|
||||
: value(ref->get_type(), name), arg_id_(arg_id){ }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// phi_node classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next)
|
||||
: instruction(ty, 0, 1, name, next) {
|
||||
: instruction(ty, INST_PHI, 0, name, next) {
|
||||
blocks_.reserve(num_reserved);
|
||||
}
|
||||
|
||||
@@ -131,7 +127,7 @@ bool binary_operator::is_int_add_sub() const {
|
||||
|
||||
|
||||
binary_operator::binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
|
||||
: instruction(ty, 2, 1, name, next), op_(op){
|
||||
: instruction(ty, INST_BINOP, 2, name, next), op_(op){
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
}
|
||||
@@ -164,6 +160,8 @@ binary_operator *binary_operator::create_not(value *arg, const std::string &name
|
||||
// cmp_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
|
||||
// cmp_inst
|
||||
std::string cmp_inst::repr_impl() const {
|
||||
switch (pred_) {
|
||||
@@ -197,8 +195,8 @@ std::string cmp_inst::repr_impl() const {
|
||||
}
|
||||
}
|
||||
|
||||
cmp_inst::cmp_inst(type *ty, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: instruction(ty, 2, 1, name, next), pred_(pred) {
|
||||
cmp_inst::cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, 2, name, next), pred_(pred) {
|
||||
set_operand(0, lhs);
|
||||
set_operand(1, rhs);
|
||||
}
|
||||
@@ -219,7 +217,12 @@ bool cmp_inst::is_int_predicate(cmp_pred_t pred) {
|
||||
return pred >= FIRST_ICMP_PREDICATE && pred <= LAST_ICMP_PREDICATE;
|
||||
}
|
||||
|
||||
|
||||
// icmp_inst
|
||||
icmp_inst::icmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: cmp_inst(ty, INST_ICMP, pred, lhs, rhs, name, next){ }
|
||||
|
||||
icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(is_int_predicate(pred));
|
||||
type *res_ty = make_cmp_result_type(lhs->get_type());
|
||||
@@ -227,6 +230,10 @@ icmp_inst* icmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std:
|
||||
}
|
||||
|
||||
// fcmp_inst
|
||||
fcmp_inst::fcmp_inst(type *ty, cmp_pred_t pred,
|
||||
value *lhs, value *rhs, const std::string &name, instruction *next)
|
||||
: cmp_inst(ty, INST_FCMP, pred, lhs, rhs, name, next){ }
|
||||
|
||||
fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next){
|
||||
assert(is_fp_predicate(pred));
|
||||
type *res_ty = make_cmp_result_type(lhs->get_type());
|
||||
@@ -237,8 +244,8 @@ fcmp_inst* fcmp_inst::create(cmp_pred_t pred, value *lhs, value *rhs, const std:
|
||||
// unary_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
unary_inst::unary_inst(type *ty, value *v, const std::string &name, instruction *next)
|
||||
: instruction(ty, 1, 1, name, next) {
|
||||
unary_inst::unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, 1, name, next) {
|
||||
set_operand(0, v);
|
||||
}
|
||||
|
||||
@@ -309,7 +316,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
|
||||
|
||||
// return_inst
|
||||
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
|
||||
: terminator_inst(type::get_void_ty(ctx), ret_val!=nullptr, 0, "", next){
|
||||
: terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
|
||||
if(ret_val)
|
||||
set_operand(0, ret_val);
|
||||
}
|
||||
@@ -332,13 +339,13 @@ branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block *
|
||||
|
||||
// uncond_branch_inst
|
||||
uncond_branch_inst::uncond_branch_inst(basic_block *dst, instruction *next)
|
||||
: branch_inst(type::get_void_ty(dst->get_context()), 1, 0, "", next){
|
||||
: branch_inst(type::get_void_ty(dst->get_context()), INST_UNCOND_BRANCH, 1, "", next){
|
||||
set_operand(0, dst);
|
||||
}
|
||||
|
||||
// cond_branch_inst
|
||||
cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next)
|
||||
: branch_inst(type::get_void_ty(if_dst->get_context()), 3, 0, "", next){
|
||||
: branch_inst(type::get_void_ty(if_dst->get_context()), INST_COND_BRANCH, 3, "", next){
|
||||
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
|
||||
set_operand(0, if_dst);
|
||||
set_operand(1, else_dst);
|
||||
@@ -351,7 +358,7 @@ cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, v
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next)
|
||||
: instruction(get_return_type(pointee_ty, ptr, idx), 1 + idx.size(), 1, name, next),
|
||||
: instruction(get_return_type(pointee_ty, ptr, idx), INST_GETELEMENTPTR, 1 + idx.size(), name, next),
|
||||
source_elt_ty(pointee_ty),
|
||||
res_elt_ty(get_indexed_type(pointee_ty, idx)){
|
||||
// sanity check
|
||||
@@ -414,8 +421,13 @@ getelementptr_inst *getelementptr_inst::create(value *ptr, const std::vector<val
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// io_inst
|
||||
io_inst::io_inst(type *ty, unsigned num_ops, unsigned num_results, const std::string &name, instruction *next)
|
||||
: instruction(ty, num_ops, num_results, name, next)
|
||||
io_inst::io_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: instruction(ty, id, num_ops, name, next)
|
||||
{ }
|
||||
|
||||
// load_inst
|
||||
load_inst::load_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), id, num_ops, name, next)
|
||||
{ }
|
||||
|
||||
// load
|
||||
@@ -427,19 +439,21 @@ type *load_inst::get_pointee_type(type *ty) {
|
||||
return pointee_ty;
|
||||
}
|
||||
|
||||
load_inst::load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next)
|
||||
: io_inst(get_pointee_type(ptr->get_type()), 1 + num_extra_ops, 1, name, next) {
|
||||
// unmasked_load
|
||||
unmasked_load_inst::unmasked_load_inst(value *ptr, const std::string &name, instruction *next)
|
||||
: load_inst(ptr, INST_UNMASKED_LOAD, 1, name, next) {
|
||||
set_operand(0, ptr);
|
||||
}
|
||||
|
||||
load_inst* load_inst::create(value *ptr, const std::string &name, instruction *next) {
|
||||
return new load_inst(ptr, 0, name, next);
|
||||
unmasked_load_inst* unmasked_load_inst::create(value *ptr, const std::string &name, instruction *next) {
|
||||
return new unmasked_load_inst(ptr, name, next);
|
||||
}
|
||||
|
||||
// masked load
|
||||
masked_load_inst::masked_load_inst(value *ptr, value *mask, value *false_value,
|
||||
const std::string &name, instruction *next)
|
||||
: load_inst(ptr, 2, name, next) {
|
||||
: load_inst(ptr, INST_MASKED_LOAD, 3, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, mask);
|
||||
set_operand(2, false_value);
|
||||
}
|
||||
@@ -450,23 +464,29 @@ masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false
|
||||
}
|
||||
|
||||
|
||||
// store
|
||||
store_inst::store_inst(value *ptr, value *val, unsigned num_extra_ops,
|
||||
const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), 2 + num_extra_ops, 1, name, next) {
|
||||
store_inst::store_inst(value *ptr, value_id_t id, unsigned num_ops, const std::string &name, instruction *next)
|
||||
: io_inst(type::get_void_ty(ptr->get_type()->get_context()), id, num_ops, name, next)
|
||||
{ }
|
||||
|
||||
// unmasked_store
|
||||
unmasked_store_inst::unmasked_store_inst(value *ptr, value *val,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, INST_UNMASKED_STORE, 2, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
|
||||
store_inst* store_inst::create(value *ptr, value *val,
|
||||
const std::string &name, instruction *next) {
|
||||
return new store_inst(ptr, val, 0, name, next);
|
||||
unmasked_store_inst* unmasked_store_inst::create(value *ptr, value *val,
|
||||
const std::string &name, instruction *next) {
|
||||
return new unmasked_store_inst(ptr, val, name, next);
|
||||
}
|
||||
|
||||
// masked store
|
||||
masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
|
||||
const std::string &name, instruction *next)
|
||||
: store_inst(ptr, val, 1, name, next) {
|
||||
: store_inst(ptr, INST_MASKED_STORE, 3, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
set_operand(2, mask);
|
||||
}
|
||||
|
||||
@@ -477,15 +497,16 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask
|
||||
// retile_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
retile_inst::retile_inst(value *arg, const type::tile_shapes_t &shapes,
|
||||
retile_inst::retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes,
|
||||
const std::string &name, instruction *next)
|
||||
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), arg, name, next) { }
|
||||
: unary_inst(tile_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { }
|
||||
|
||||
|
||||
// reshape
|
||||
|
||||
instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new reshape_inst(arg, shapes, name, next);
|
||||
return new reshape_inst(arg, INST_RESHAPE, shapes, name, next);
|
||||
}
|
||||
|
||||
|
||||
@@ -493,20 +514,20 @@ instruction* reshape_inst::create(value *arg, const type::tile_shapes_t &shapes,
|
||||
|
||||
instruction* splat_inst::create(value *arg, const type::tile_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new splat_inst(arg, shapes, name, next);
|
||||
return new splat_inst(arg, INST_SPLAT, shapes, name, next);
|
||||
}
|
||||
|
||||
// broadcast
|
||||
|
||||
instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shapes,
|
||||
const std::string &name, instruction *next) {
|
||||
return new broadcast_inst(arg, shapes, name, next);
|
||||
return new broadcast_inst(arg, INST_BROADCAST, shapes, name, next);
|
||||
}
|
||||
|
||||
// downcast
|
||||
|
||||
instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new downcast_inst(arg->get_type()->get_scalar_ty(), arg, name, next);
|
||||
return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -515,7 +536,7 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
|
||||
|
||||
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
|
||||
const std::string &name, instruction *next)
|
||||
: builtin_inst(C->get_type(), 3, 1, name, next), AT_(AT), BT_(BT) {
|
||||
: builtin_inst(C->get_type(), INST_DOT, 3, name, next), AT_(AT), BT_(BT) {
|
||||
set_operand(0, A);
|
||||
set_operand(1, B);
|
||||
set_operand(2, C);
|
||||
@@ -578,7 +599,7 @@ std::vector<constant_int*> trans_inst::init_perm(ir::type* ty, const std::vector
|
||||
}
|
||||
|
||||
trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next)
|
||||
: builtin_inst(get_res_ty(arg->get_type(), perm), 1, 1, name, next) {
|
||||
: builtin_inst(get_res_ty(arg->get_type(), perm), INST_TRANS, 1, name, next) {
|
||||
// sanity check
|
||||
perm_ = init_perm(arg->get_type(), perm);
|
||||
//auto size = arg->get_type()->get_tile_shapes().size();
|
||||
@@ -599,7 +620,7 @@ const std::vector<constant_int*> trans_inst::get_perm() const {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
sqrt_inst::sqrt_inst(value *arg, const std::string &name, instruction *next)
|
||||
: builtin_inst(arg->get_type(), 1, 1, name, next){
|
||||
: builtin_inst(arg->get_type(), INST_SQRT, 1, name, next){
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
||||
@@ -621,7 +642,7 @@ type* reduce_inst::get_res_type(value *arg, unsigned axis) {
|
||||
}
|
||||
|
||||
reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
|
||||
: builtin_inst(get_res_type(arg, axis), INST_REDUCE, 1, name, next),
|
||||
axis_(axis){
|
||||
set_operand(0, arg);
|
||||
}
|
||||
@@ -636,7 +657,7 @@ instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &n
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next)
|
||||
: builtin_inst(if_value->get_type(), 3, 1, name, next){
|
||||
: builtin_inst(if_value->get_type(), INST_SELECT, 3, name, next){
|
||||
set_operand(0, pred);
|
||||
set_operand(1, if_value);
|
||||
set_operand(2, else_value);
|
||||
@@ -652,7 +673,7 @@ instruction* select_inst::create(value *pred, value *if_value, value *else_value
|
||||
|
||||
// get_program_id
|
||||
get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(ty, 0, 1, name, next), axis_(axis){
|
||||
: builtin_inst(ty, INST_GET_PROGRAM_ID, 0, name, next), axis_(axis){
|
||||
|
||||
}
|
||||
|
||||
@@ -662,7 +683,7 @@ instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std:
|
||||
|
||||
// get_num_program
|
||||
get_num_program_inst::get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(ty, 0, 1, name, next), axis_(axis){
|
||||
: builtin_inst(ty, INST_GET_NUM_PROGRAMS, 0, name, next), axis_(axis){
|
||||
|
||||
}
|
||||
|
||||
@@ -674,7 +695,7 @@ instruction* get_num_program_inst::create(context &ctx, unsigned axis, const std
|
||||
// atomic cas
|
||||
|
||||
atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), 3, 1, name, next) {
|
||||
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_CAS, 3, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, cmp);
|
||||
set_operand(2, val);
|
||||
@@ -687,7 +708,7 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s
|
||||
// atomic exch
|
||||
|
||||
atomic_exch_inst::atomic_exch_inst(value *ptr, value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), 2, 1, name, next) {
|
||||
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_EXCH, 2, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
@@ -699,7 +720,7 @@ instruction* atomic_exch_inst::create(value *ptr, value *val, const std::string
|
||||
// atomic add
|
||||
|
||||
atomic_add_inst::atomic_add_inst(value *ptr, value *val, const std::string &name, instruction *next)
|
||||
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), 2, 1, name, next) {
|
||||
: builtin_inst(ptr->get_type()->get_pointer_element_ty(), INST_ATOMIC_ADD, 2, name, next) {
|
||||
set_operand(0, ptr);
|
||||
set_operand(1, val);
|
||||
}
|
||||
@@ -714,18 +735,13 @@ instruction* atomic_add_inst::create(value *ptr, value *val, const std::string &
|
||||
// copy to shared
|
||||
copy_to_shared_inst* copy_to_shared_inst::create(value *arg, const std::string &name,
|
||||
instruction *next) {
|
||||
return new copy_to_shared_inst(arg->get_type(), arg, name, next);
|
||||
}
|
||||
|
||||
// vectorize
|
||||
vectorize_inst* vectorize_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new vectorize_inst(arg->get_type(), arg, name, next);
|
||||
return new copy_to_shared_inst(arg->get_type(), INST_COPY_TO_SHARED, arg, name, next);
|
||||
}
|
||||
|
||||
// barrier
|
||||
barrier_inst::barrier_inst(context &ctx, const std::string &name,
|
||||
instruction *next)
|
||||
: instruction(type::get_void_ty(ctx), 0, 0, name, next) { }
|
||||
: instruction(type::get_void_ty(ctx), INST_BARRIER, 0, name, next) { }
|
||||
|
||||
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
|
||||
return new barrier_inst(ctx, name, next);
|
||||
@@ -734,7 +750,7 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
|
||||
|
||||
// nv_dynamic_program_idx
|
||||
make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next)
|
||||
: instruction(ty, 0, 1, name, next) { }
|
||||
: instruction(ty, INST_MAKE_RANGE_DYN, 0, name, next) { }
|
||||
|
||||
make_range_dyn* make_range_dyn::create(type *ty, const std::string &name, instruction *next) {
|
||||
return new make_range_dyn(ty, name, next);
|
||||
@@ -757,7 +773,7 @@ make_range_sta* make_range_sta::get(make_range* range) {
|
||||
|
||||
// make_range
|
||||
make_range::make_range(type *ty, constant_int *first, constant_int *last)
|
||||
: instruction(ty, 0), first_(first), last_(last){ }
|
||||
: instruction(ty, INST_MAKE_RANGE, 0), first_(first), last_(last){ }
|
||||
|
||||
make_range *make_range::create(constant_int *first, constant_int *last) {
|
||||
assert(first->get_type()->is_integer_ty());
|
||||
|
54
lib/ir/utils.cc
Normal file
54
lib/ir/utils.cc
Normal file
@@ -0,0 +1,54 @@
|
||||
#include <stack>
|
||||
#include <iostream>
|
||||
#include "triton/ir/utils.h"
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/module.h"
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
// Collects the basic blocks of `fn` that are reachable from its entry blocks
// (blocks without predecessors) using an explicit-stack depth-first traversal.
// Every reachable block appears exactly once in the returned vector.
// NOTE(review): a block is appended when it is popped from the stack, i.e. in
// depth-first discovery order; despite the name this is not a strict reverse
// post-order -- confirm that callers do not rely on one.
std::vector<basic_block*> cfg::reverse_post_order(function* fn) {
  std::stack<basic_block*> stack;
  std::set<basic_block*> visited;
  std::vector<basic_block*> result;
  // initialize stack with every block that has no predecessor
  for(ir::basic_block* block: fn->blocks())
    if(block->get_predecessors().empty())
      stack.push(block);
  // DFS
  while(!stack.empty()) {
    basic_block* current = stack.top();
    stack.pop();
    // A block with several predecessors can be pushed more than once before
    // it is first popped; only its first pop may contribute to the result.
    if(!visited.insert(current).second)
      continue;
    result.push_back(current);
    for(basic_block* succ: current->get_successors())
      if(visited.find(succ) == visited.end())
        stack.push(succ);
  }
  // plain return of the local: eligible for copy elision / implicit move
  return result;
}
|
||||
|
||||
// Applies `do_work` to every instruction of every function in `mod`.
// For each function, blocks are walked in the order returned by
// cfg::reverse_post_order, and instructions in the order of the block's
// instruction list.
void for_each_instruction(module &mod, const std::function<void (instruction *)> &do_work) {
  for(ir::function *fn: mod.get_function_list()) {
    const auto ordered_blocks = cfg::reverse_post_order(fn);
    for(ir::basic_block *bb: ordered_blocks) {
      for(ir::instruction *inst: bb->get_inst_list()) {
        do_work(inst);
      }
    }
  }
}
|
||||
|
||||
// Applies `do_work` once to every distinct value encountered while walking
// the instructions of `mod`: for each instruction, its operands are offered
// first (in operand order), then the instruction itself. A value that was
// already offered earlier is skipped.
void for_each_value(module &mod, const std::function<void (value *)> &do_work) {
  std::set<ir::value*> visited;
  // runs do_work on v only the first time v is encountered
  auto visit_once = [&](ir::value *v) {
    const bool first_time = visited.insert(v).second;
    if(first_time)
      do_work(v);
  };
  for(ir::function *fn: mod.get_function_list()) {
    for(ir::basic_block *bb: cfg::reverse_post_order(fn)) {
      for(ir::instruction *inst: bb->get_inst_list()) {
        for(ir::value *operand: inst->ops())
          visit_once(operand);
        visit_once(inst);
      }
    }
  }
}
|
||||
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user