This commit is contained in:
Philippe Tillet
2019-09-08 17:35:24 -04:00
parent 0ff81badac
commit 32234c2612
10 changed files with 541 additions and 301 deletions

View File

@@ -2,12 +2,19 @@
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H #define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#include <map> #include <map>
#include <vector>
namespace triton { namespace triton {
namespace ir { namespace ir {
class value; class value;
class module; class module;
class phi_node;
class splat_inst;
class reshape_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
} }
namespace codegen{ namespace codegen{
@@ -22,22 +29,47 @@ class align {
private: private:
// helpers // helpers
bool is_first_axis_unit(ir::value *v); bool is_first_axis_unit(ir::value *v);
std::vector<unsigned> get_shapes(ir::value *v);
// populate maps // populate is_constant
cst_info populate_is_constant(ir::value *v); std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
unsigned populate_max_contiguous(ir::value *v); std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
unsigned populate_starting_multiple(ir::value *v); std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
std::vector<cst_info> populate_is_constant_default(ir::value* v);
std::vector<cst_info> populate_is_constant(ir::value *v);
// populate max_contiguous
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
std::vector<unsigned> populate_max_contiguous(ir::value *v);
// populate starting_multiple
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v);
public: public:
void run(ir::module &mod); void run(ir::module &mod);
unsigned get_starting_multiple(ir::value* v) const; unsigned get_starting_multiple(ir::value* v) const;
unsigned get_max_contiguous(ir::value* v) const; unsigned get_max_contiguous(ir::value* v) const;
std::vector<unsigned> get_max_contiguous_vec(ir::value* v) const;
void copy(ir::value *dst, ir::value *src); void copy(ir::value *dst, ir::value *src);
private: private:
std::map<ir::value*, cst_info> is_constant_; std::map<ir::value*, std::vector<cst_info>> is_constant_;
std::map<ir::value*, unsigned> max_contiguous_; std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
std::map<ir::value*, unsigned> starting_multiple_; std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
}; };

View File

@@ -11,6 +11,10 @@
#include "triton/ir/type.h" #include "triton/ir/type.h"
#include "triton/ir/metadata.h" #include "triton/ir/metadata.h"
#define _TRITON_DEFINE_CLONE(name) \
ir::instruction* clone_impl() const { return new name(*this); }
namespace triton{ namespace triton{
namespace ir{ namespace ir{
@@ -25,10 +29,15 @@ class context;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class result_reference; class result_reference;
class instruction: public user{ class instruction: public user{
public: public:
virtual std::string repr_impl() const = 0; virtual std::string repr_impl() const = 0;
private:
virtual ir::instruction* clone_impl() const = 0;
protected: protected:
// constructors // constructors
instruction(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr); instruction(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
@@ -43,19 +52,27 @@ public:
bool has_tile_result_or_op(); bool has_tile_result_or_op();
// repr // repr
std::string repr() const { return repr_impl(); } std::string repr() const { return repr_impl(); }
// results
unsigned get_num_results() const { return results_.size(); }
value* get_result(unsigned i) { return results_.at(i); }
// metadata // metadata
void set_metadata(ir::metadata::kind_t kind, void set_metadata(ir::metadata::kind_t kind,
unsigned value) { metadatas_[kind] = value;} unsigned value) { metadatas_[kind] = value;}
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];} unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
// cloning
ir::instruction* clone() {
ir::instruction* res = clone_impl();
// for(auto it = op_begin(); it != op_end(); it++){
// (*it)->add_use(res);
// }
res->set_name("testcloned");
res->parent_ = nullptr;
return res;
}
private: private:
basic_block *parent_; basic_block *parent_;
std::vector<value*> results_;
std::map<ir::metadata::kind_t, unsigned> metadatas_; std::map<ir::metadata::kind_t, unsigned> metadatas_;
}; };
// result reference // result reference
class result_reference: public value { class result_reference: public value {
public: public:
@@ -72,7 +89,7 @@ private:
// phi_node classes // phi_node classes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class phi_node: public instruction{ class phi_node: public instruction {
private: private:
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next); phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
std::string repr_impl() const { return "phi"; } std::string repr_impl() const { return "phi"; }
@@ -91,6 +108,8 @@ public:
// Factory methods // Factory methods
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr); static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(phi_node)
private: private:
unsigned num_reserved_; unsigned num_reserved_;
std::vector<basic_block*> blocks_; std::vector<basic_block*> blocks_;
@@ -99,7 +118,7 @@ private:
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// binary_operator classes // binary_operator classes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class binary_operator: public instruction{ class binary_operator: public instruction {
public: public:
typedef binary_op_t op_t; typedef binary_op_t op_t;
@@ -138,6 +157,8 @@ public:
static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr); static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr); static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(binary_operator)
public: public:
binary_op_t op_; binary_op_t op_;
bool has_no_unsigned_wrap_; bool has_no_unsigned_wrap_;
@@ -168,20 +189,22 @@ private:
cmp_pred_t pred_; cmp_pred_t pred_;
}; };
class icmp_inst: public cmp_inst{ class icmp_inst: public cmp_inst {
using cmp_inst::cmp_inst; using cmp_inst::cmp_inst;
public: public:
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs, static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(icmp_inst)
}; };
class fcmp_inst: public cmp_inst{ class fcmp_inst: public cmp_inst {
using cmp_inst::cmp_inst; using cmp_inst::cmp_inst;
public: public:
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs, static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(fcmp_inst)
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -224,7 +247,8 @@ private:
}; };
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \ #define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \
class name : public cast_inst{ \ class name : public cast_inst { \
_TRITON_DEFINE_CLONE(name); \
friend class cast_inst; \ friend class cast_inst; \
name(type *ty, value *v, const std::string &name, instruction *next) \ name(type *ty, value *v, const std::string &name, instruction *next) \
: cast_inst(ty, v, name, next, op){ } \ : cast_inst(ty, v, name, next, op){ } \
@@ -253,7 +277,7 @@ class terminator_inst: public instruction{
}; };
// return instruction // return instruction
class return_inst: public terminator_inst{ class return_inst: public terminator_inst {
private: private:
std::string repr_impl() const { return "ret"; } std::string repr_impl() const { return "ret"; }
return_inst(context &ctx, value *ret_val, instruction *next); return_inst(context &ctx, value *ret_val, instruction *next);
@@ -267,6 +291,8 @@ public:
// factory methods // factory methods
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr); static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
_TRITON_DEFINE_CLONE(return_inst)
}; };
// base branch instruction // base branch instruction
@@ -294,6 +320,7 @@ public:
basic_block *get_true_dest() { return (basic_block*)get_operand(0); } basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
basic_block *get_false_dest() { return (basic_block*)get_operand(1); } basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
value *get_cond() { return get_operand(2); } value *get_cond() { return get_operand(2); }
_TRITON_DEFINE_CLONE(cond_branch_inst)
}; };
// unconditional branch // unconditional branch
@@ -304,28 +331,15 @@ private:
public: public:
basic_block *get_dest() { return (basic_block*)get_operand(0); } basic_block *get_dest() { return (basic_block*)get_operand(0); }
_TRITON_DEFINE_CLONE(uncond_branch_inst)
}; };
// ternary
class ternary_inst: public instruction {
private:
std::string repr_impl() const { return "cond"; }
ternary_inst(value *cond, value *true_value, value *false_value,
const std::string &name, instruction *next);
public:
value *get_cond() { return get_operand(0); }
value *get_true_value() { return get_operand(1); }
value *get_false_value() { return get_operand(2); }
static ternary_inst* create(value *cond, value *true_value, value *false_value,
const std::string &name = "", instruction *next = nullptr);
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// getelementptr_inst classes // getelementptr_inst classes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class getelementptr_inst: public instruction{ class getelementptr_inst: public instruction {
private: private:
std::string repr_impl() const { return "getelementptr"; } std::string repr_impl() const { return "getelementptr"; }
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next); getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
@@ -345,6 +359,7 @@ public:
// factory methods // factory methods
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx, static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(getelementptr_inst)
private: private:
type *source_elt_ty; type *source_elt_ty;
@@ -358,12 +373,16 @@ private:
class io_inst: public instruction { class io_inst: public instruction {
protected: protected:
io_inst(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr); io_inst(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
public: public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
// value *get_mask() const; // value *get_mask() const;
// value *get_false_value() const; // value *get_false_value() const;
}; };
class load_inst: public io_inst{ class load_inst: public io_inst {
protected: protected:
load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next); load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next);
@@ -372,15 +391,15 @@ private:
static type *get_pointee_type(type *ty); static type *get_pointee_type(type *ty);
public: public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
// factory method // factory method
static load_inst* create(value *ptr, static load_inst* create(value *ptr,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(load_inst)
}; };
class masked_load_inst: public load_inst{ class masked_load_inst: public load_inst {
private: private:
std::string repr_impl() const { return "masked_load"; } std::string repr_impl() const { return "masked_load"; }
masked_load_inst(value *ptr, value *mask, value *false_value, masked_load_inst(value *ptr, value *mask, value *false_value,
@@ -394,6 +413,7 @@ public:
static masked_load_inst* create(value *ptr, value *mask, value *false_value, static masked_load_inst* create(value *ptr, value *mask, value *false_value,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_inst)
}; };
class store_inst: public io_inst{ class store_inst: public io_inst{
@@ -406,12 +426,12 @@ private:
public: public:
// accessors // accessors
value *get_pointer_operand() { return get_operand(0); }
value *get_value_operand() { return get_operand(1); } value *get_value_operand() { return get_operand(1); }
// factory method // factory method
static store_inst* create(value* ptr, value *v, static store_inst* create(value* ptr, value *v,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(store_inst)
}; };
class masked_store_inst: public store_inst{ class masked_store_inst: public store_inst{
@@ -427,6 +447,7 @@ public:
static masked_store_inst* create(value *ptr, value *v, value *mask, static masked_store_inst* create(value *ptr, value *v, value *mask,
const std::string &name = "", const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst)
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -450,6 +471,7 @@ private:
public: public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(reshape_inst)
}; };
// splat // splat
@@ -462,6 +484,7 @@ private:
public: public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(splat_inst)
}; };
// broadcast // broadcast
@@ -474,6 +497,7 @@ private:
public: public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix, static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr); const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(broadcast_inst)
}; };
@@ -486,6 +510,7 @@ private:
public: public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(downcast_inst)
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -505,6 +530,7 @@ private:
public: public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr); static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; } unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_program_id_inst)
private: private:
unsigned axis_; unsigned axis_;
@@ -518,6 +544,7 @@ private:
public: public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr); static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; } unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_num_program_inst)
private: private:
unsigned axis_; unsigned axis_;
@@ -527,6 +554,7 @@ class atomic_cas_inst: public builtin_inst {
private: private:
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next); atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
std::string repr_impl() const { return "atomic_cas"; } std::string repr_impl() const { return "atomic_cas"; }
_TRITON_DEFINE_CLONE(atomic_cas_inst)
public: public:
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
@@ -536,6 +564,7 @@ class atomic_exch_inst: public builtin_inst {
private: private:
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_exch"; } std::string repr_impl() const { return "atomic_exch"; }
_TRITON_DEFINE_CLONE(atomic_exch_inst)
public: public:
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
@@ -545,6 +574,7 @@ class atomic_add_inst: public builtin_inst {
private: private:
atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_add"; } std::string repr_impl() const { return "atomic_add"; }
_TRITON_DEFINE_CLONE(atomic_add_inst)
public: public:
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
@@ -566,6 +596,7 @@ public:
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
bool is_a_trans() { return AT_ == Trans; } bool is_a_trans() { return AT_ == Trans; }
bool is_b_trans() { return BT_ == Trans; } bool is_b_trans() { return BT_ == Trans; }
_TRITON_DEFINE_CLONE(dot_inst)
private: private:
TransT AT_; TransT AT_;
@@ -586,17 +617,12 @@ public:
private: private:
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next); trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
std::string repr_impl() const { std::string repr_impl() const { return "trans"; }
std::string res = "trans<";
//for(ir::constant_int *x: perm_)
// res += x->repr() + ",";
res[res.size()-1] = '>';
return res;
}
public: public:
static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
const std::vector<constant_int*> get_perm() const; const std::vector<constant_int*> get_perm() const;
_TRITON_DEFINE_CLONE(trans_inst)
private: private:
std::vector<constant_int*> perm_; std::vector<constant_int*> perm_;
@@ -608,6 +634,7 @@ private:
std::string repr_impl() const { return "sqrt"; } std::string repr_impl() const { return "sqrt"; }
public: public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(sqrt_inst)
}; };
class reduce_inst: public builtin_inst { class reduce_inst: public builtin_inst {
@@ -617,6 +644,7 @@ private:
private: private:
reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next); reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
std::string repr_impl() const { return "reduce"; } std::string repr_impl() const { return "reduce"; }
_TRITON_DEFINE_CLONE(reduce_inst)
public: public:
static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
@@ -630,6 +658,7 @@ class select_inst: public builtin_inst {
private: private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next); select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
std::string repr_impl() const { return "select"; } std::string repr_impl() const { return "select"; }
_TRITON_DEFINE_CLONE(select_inst)
public: public:
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr); static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
@@ -647,12 +676,14 @@ private:
public: public:
static copy_to_shared_inst* create(value *arg, const std::string &name = "", static copy_to_shared_inst* create(value *arg, const std::string &name = "",
instruction *next = nullptr); instruction *next = nullptr);
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
}; };
class barrier_inst: public instruction{ class barrier_inst: public instruction{
private: private:
barrier_inst(context &ctx, const std::string &name, instruction *next); barrier_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "barrier"; } std::string repr_impl() const { return "barrier"; }
_TRITON_DEFINE_CLONE(barrier_inst)
public: public:
static barrier_inst* create(context &ctx, const std::string &name = "", static barrier_inst* create(context &ctx, const std::string &name = "",
@@ -663,6 +694,7 @@ class vectorize_inst: public unary_inst{
private: private:
using unary_inst::unary_inst; using unary_inst::unary_inst;
std::string repr_impl() const { return "vectorize"; } std::string repr_impl() const { return "vectorize"; }
_TRITON_DEFINE_CLONE(vectorize_inst)
public: public:
static vectorize_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr); static vectorize_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
@@ -675,6 +707,7 @@ class nv_dynamic_program_idx_inst: public instruction {
private: private:
nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next); nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next);
std::string repr_impl() const { return "nv_dynamic_program_idx"; } std::string repr_impl() const { return "nv_dynamic_program_idx"; }
_TRITON_DEFINE_CLONE(nv_dynamic_program_idx_inst)
public: public:
static nv_dynamic_program_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr); static nv_dynamic_program_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);

View File

@@ -5,6 +5,8 @@
#include "triton/ir/instructions.h" #include "triton/ir/instructions.h"
#include "triton/ir/type.h" #include "triton/ir/type.h"
#include <iostream> #include <iostream>
#include <numeric>
#include <algorithm>
namespace triton { namespace triton {
namespace codegen{ namespace codegen{
@@ -36,258 +38,448 @@ bool align::is_first_axis_unit(ir::value *x){
return true; return true;
} }
align::cst_info align::populate_is_constant(ir::value *v) { /*
* is constant
*/
std::vector<unsigned> align::get_shapes(ir::value *v) {
ir::type *ty = v->get_type();
if(ty->is_tile_ty())
return ty->get_tile_shapes();
else
return {1};
}
std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto it = is_constant_.find(inc);
if(it != is_constant_.end())
result = it->second;
}
return add_to_cache(x, result, is_constant_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto cst = populate_is_constant(inc);
for(size_t d = 0; d < cst.size(); d++)
result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
auto shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value* op = x->get_operand(0);
auto op_cst = populate_is_constant(op);
for(auto d: shapes)
result.push_back(cst_info{d, op_cst[0].value});
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_cst = populate_is_constant(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < x_shapes.size(); d ++){
cst_info ax ;
if(x_shapes[d] == 1)
ax = {1, op_cst[current].value};
else if(!is_skewed
&& x_shapes[d] == op_shapes[current])
ax = {x_shapes[d], op_cst[current++].value};
else {
is_skewed = true;
ax = {x_shapes[d], 0};
}
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_cst = populate_is_constant(op);
for(size_t d = 0; d < x_shapes.size(); d++)
if(op_shapes[d] == 1)
result.push_back(cst_info{x_shapes[d], op_cst[d].value});
else
result.push_back(op_cst[d]);
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
auto x_shapes = get_shapes(x);
std::vector<cst_info> result;
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
auto max_contiguous = populate_max_contiguous(lhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax;
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
// todo might not be entirely true
unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
ax = {num_constants, 0};
}
else
ax = {std::min(lhs[d].num_cst, rhs[d].num_cst), 0};
result.push_back(ax);
}
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
auto x_shapes = get_shapes(x);
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
auto lhs = populate_is_constant(lhs_op);
auto rhs = populate_is_constant(rhs_op);
std::vector<cst_info> result;
for(size_t d = 0; d < x_shapes.size(); d++)
result.push_back({std::min(lhs[d].num_cst, rhs[d].num_cst), 0});
return add_to_cache(x, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
auto shapes = get_shapes(v);
std::vector<cst_info> result(shapes.size(), {1, 0});
return add_to_cache(v, result, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
if(is_constant_.find(v) != is_constant_.end()) if(is_constant_.find(v) != is_constant_.end())
return is_constant_.at(v); return is_constant_.at(v);
// helper for the cache
auto cache = [this,v](cst_info value){
return add_to_cache(v, value, is_constant_); }
;
// populate
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
ir::value *op = x->get_operand(0);
auto op_cst = populate_is_constant(op);
if(is_first_axis_unit(op)){
unsigned num_cst = x->get_type()->get_tile_shapes()[0];
return cache({num_cst, op_cst.value});
}
}
if(auto *x = dynamic_cast<ir::constant_int*>(v)) if(auto *x = dynamic_cast<ir::constant_int*>(v))
return cache({true, (unsigned)x->get_value()}); return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_);
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){ if(auto *x = dynamic_cast<ir::phi_node*>(v))
ir::value* lhs_op = x->get_operand(0); return populate_is_constant_phi(x);
ir::value* rhs_op = x->get_operand(1); if(auto *x = dynamic_cast<ir::splat_inst*>(v))
cst_info lhs = populate_is_constant(lhs_op); return populate_is_constant_splat(x);
cst_info rhs = populate_is_constant(rhs_op); if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
if(lhs.num_cst==0 && rhs.value && x->is_int_div()){ return populate_is_constant_reshape(x);
unsigned max_contiguous = populate_max_contiguous(lhs_op); if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
// todo might not be entirely true return populate_is_constant_broadcast(x);
unsigned num_constants = gcd(max_contiguous, rhs.value); if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return cache({num_constants, 0}); return populate_is_constant_binop(x);
} if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return cache({std::min(lhs.num_cst, rhs.num_cst), 0}); return populate_is_constant_gep(x);
} return populate_is_constant_default(v);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
cst_info lhs = populate_is_constant(lhs_op);
cst_info rhs = populate_is_constant(rhs_op);
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
}
// if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
// cst_info value_true = populate_is_constant(x->get_value_true());
// cst_info value_false = populate_is_constant(x->get_value_false());
// return cache({std::min(value_true.num_cst, value_false.num_cst), 0});
// }
if(v->get_type()->is_tile_ty())
return cache({0, 0});
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
// put a conservative initial value in phi node to avoid infinite recursion
unsigned result = 1;
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
if(is_constant_.find(inc) != is_constant_.end())
result = is_constant_.at(inc).num_cst;
}
cache({result, 0});
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
result = std::min(result, populate_is_constant(inc).num_cst);
}
return cache({result, 0});
}
// scalars are always constant in the contiguous dimension
// but value is not known at compile-time
return cache({1, 0});
} }
unsigned align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end()) /*
return max_contiguous_.at(v); * max contiguous
// helper for the cache */
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
// populate std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
if(!v->get_type()->is_tile_ty()) auto shapes = get_shapes(x);
return cache(1); std::vector<unsigned> result(shapes.size(), 1);
auto shapes = v->get_type()->get_tile_shapes(); for(unsigned n = 0; n < x->get_num_incoming(); n++){
if(dynamic_cast<ir::constant_range*>(v)){ ir::value* inc = x->get_incoming_value(n);
return cache(shapes[0]); auto it = max_contiguous_.find(inc);
if(it != max_contiguous_.end())
result = it->second;
} }
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){ add_to_cache(x, result, max_contiguous_);
ir::value *op = x->get_operand(0); // recurse
if(op->get_type()->is_tile_ty()){ for(unsigned n = 0; n < x->get_num_incoming(); n++){
auto op_shapes = op->get_type()->get_tile_shapes(); ir::value* inc = x->get_incoming_value(n);
if(op_shapes[0] == shapes[0]) auto contiguous = populate_max_contiguous(inc);
return cache(populate_max_contiguous(op)); for(size_t d = 0; d < result.size(); d++)
result[d] = std::min(result[d], contiguous[d]);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
auto x_shapes = get_shapes(x);
std::vector<unsigned> result;
for(size_t d = 0; d < x_shapes.size(); d++)
result.push_back({1});
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
auto op_mc = populate_max_contiguous(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result.push_back(1);
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result.push_back(op_mc[current++]);
else {
is_skewed = true;
result.push_back(1);
} }
return cache(1);
} }
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){ return add_to_cache(x, result, max_contiguous_);
ir::value* lhs = x->get_operand(0); }
ir::value* rhs = x->get_operand(1);
unsigned lhs_max_contiguous = populate_max_contiguous(lhs); std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
unsigned rhs_max_contiguous = populate_max_contiguous(rhs); auto shapes = get_shapes(x);
cst_info lhs_cst_info = populate_is_constant(lhs); std::vector<unsigned> result;
cst_info rhs_cst_info = populate_is_constant(rhs); ir::value *op = x->get_operand(0);
if(x->is_int_rem() && rhs_cst_info.value > 0) auto op_shapes = op->get_type()->get_tile_shapes();
return cache(std::min(lhs_max_contiguous, rhs_cst_info.value)); auto op_mc = populate_max_contiguous(op);
for(size_t d = 0; d < shapes.size(); d++)
if(op_shapes[d] == 1)
result.push_back(1);
else
result.push_back(op_mc[d]);
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
std::vector<unsigned> result;
for(size_t d = 0; d < shapes.size(); d++){
unsigned value = 1;
if(x->is_int_rem() && rhs_cst_info[d].value > 0)
value = std::min(lhs_max_contiguous[d], rhs_cst_info[d].value);
if(x->is_int_mult()){ if(x->is_int_mult()){
if(rhs_cst_info.value == 1) unsigned lvalue = 1, rvalue = 1;
return cache(lhs_max_contiguous); if(rhs_cst_info[d].value == 1)
if(lhs_cst_info.value == 1) lvalue = lhs_max_contiguous[d];
return cache(rhs_max_contiguous); if(lhs_cst_info[d].value == 1)
rvalue = rhs_max_contiguous[d];
value = std::max(lvalue, rvalue);
} }
if(x->is_int_add_sub()){ if(x->is_int_add_sub()){
if(lhs_cst_info.num_cst) unsigned lvalue = 1, rvalue = 1;
return cache(gcd(rhs_max_contiguous, lhs_cst_info.num_cst)); if(lhs_cst_info[d].num_cst)
if(rhs_cst_info.num_cst) lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
return cache(gcd(lhs_max_contiguous, rhs_cst_info.num_cst)); if(rhs_cst_info[d].num_cst)
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
value = std::max(lvalue, rvalue);
} }
result.push_back(value);
} }
// if(auto *x = dynamic_cast<ir::psi_inst*>(v)){ return add_to_cache(x, result, max_contiguous_);
// int value_true = populate_max_contiguous(x->get_value_true());
// int value_false = populate_max_contiguous(x->get_value_false());
// return cache(std::min(value_true, value_false));
// }
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
if(lhs_cst_info.num_cst)
return cache(rhs_max_contiguous);
if(rhs_cst_info.num_cst)
return cache(lhs_max_contiguous);
}
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
// put a conservative initial value in phi node to avoid infinite recursion
unsigned result = 1;
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
if(max_contiguous_.find(inc) != max_contiguous_.end())
result = max_contiguous_.at(inc);
}
cache(result);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
result = std::min(result, populate_max_contiguous(inc));
}
return cache(result);
}
return cache(1);
} }
unsigned align::populate_starting_multiple(ir::value *v){ std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
if(starting_multiple_.find(v) != starting_multiple_.end()) auto shapes = get_shapes(x);
return starting_multiple_.at(v); ir::value* lhs = x->get_operand(0);
auto cache = [this,v](unsigned value){ ir::value* rhs = x->get_operand(1);
return add_to_cache(v, value, starting_multiple_); auto lhs_max_contiguous = populate_max_contiguous(lhs);
}; auto rhs_max_contiguous = populate_max_contiguous(rhs);
// has metadata auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
std::vector<unsigned> result(shapes.size(), 1);
for(size_t d = 0; d < shapes.size(); d++){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst)
lvalue = rhs_max_contiguous[d];
if(rhs_cst_info[d].num_cst)
rvalue = lhs_max_contiguous[d];
result[d] = std::max(lvalue, rvalue);
}
return add_to_cache(x, result, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
if(!v->get_type()->is_tile_ty())
return add_to_cache(v, {1}, max_contiguous_);
auto shapes = v->get_type()->get_tile_shapes();
if(dynamic_cast<ir::constant_range*>(v))
return add_to_cache(v, {shapes[0]}, max_contiguous_);
return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_max_contiguous_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_max_contiguous_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_max_contiguous_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_max_contiguous_binop(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_max_contiguous_gep(x);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_max_contiguous_phi(x);
return populate_max_contiguous_default(v);
}
/*
* starting multiple
*/
std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
auto shapes = get_shapes(x);
auto op = populate_starting_multiple(x->get_operand(0));
std::vector<unsigned> result(shapes.size(), op[0]);
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
auto op = populate_starting_multiple(x->get_operand(0));
auto op_shapes = get_shapes(x->get_operand(0));
auto shapes = get_shapes(x);
std::vector<unsigned> result(shapes.size(), 1);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result[d] = 1;
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result[d] = op[current++];
else {
is_skewed = true;
result[d] = 1;
}
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
auto result = populate_starting_multiple(x->get_operand(0));
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
auto lhs = populate_starting_multiple(x->get_operand(0));
auto rhs = populate_starting_multiple(x->get_operand(1));
std::vector<unsigned> result(lhs.size(), 1);
for(size_t d = 0; d < lhs.size(); d++){
if(x->is_int_mult())
result[d] = lhs[d] * rhs[d];
if(x->is_int_add_sub())
result[d] = gcd(lhs[d], rhs[d]);
if(x->is_int_div())
result[d] = std::max<unsigned>(lhs[d] / rhs[d], 1);
if(x->is_int_rem() && rhs[d] > 1)
result[d] = gcd(lhs[d], rhs[d]);
if(x->is_shl())
result[d] = lhs[d] << rhs[d];
if(x->is_shr())
result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
auto lhs = populate_starting_multiple(x->get_operand(0));
auto rhs = populate_starting_multiple(x->get_operand(1));
std::vector<unsigned> result(lhs.size(), 1);
for(size_t d = 0; d < lhs.size(); d++)
result[d] = gcd(lhs[d], rhs[d]);
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
auto shape = get_shapes(x);
std::vector<unsigned> result(shape.size(), 1);
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
if(starting_multiple_.find(inc) != starting_multiple_.end())
result = starting_multiple_.at(inc);
}
add_to_cache(x, result, starting_multiple_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto sm = populate_starting_multiple(inc);
for(size_t d = 0; d < result.size(); d++)
result[d] = gcd(result[d], sm[d]);
}
return add_to_cache(x, result, starting_multiple_);
}
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
ir::type* ty = v->get_type();
if(ty->is_tile_ty()) {
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
}
if(auto *x = dynamic_cast<ir::instruction*>(v)){ if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of); unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0) if(multiple_of > 0)
return cache(multiple_of); return add_to_cache(x, {multiple_of}, starting_multiple_);
} }
// arguments
if(auto *x = dynamic_cast<ir::argument*>(v)){ if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x); std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
for(auto attr: attributes){ for(auto attr: attributes){
if(attr.get_kind() == ir::multiple_of) if(attr.get_kind() == ir::multiple_of){
return cache(attr.get_value()); return add_to_cache(x, {attr.get_value()}, starting_multiple_);
}
if(attr.get_kind() == ir::aligned){ if(attr.get_kind() == ir::aligned){
ir::type* ty = x->get_type()->get_pointer_element_ty(); ir::type* ty = x->get_type()->get_pointer_element_ty();
int nbits = ty->get_primitive_size_in_bits(); int nbits = ty->get_primitive_size_in_bits();
int nbytes = nbits / 8; int nbytes = nbits / 8;
return cache(attr.get_value() / nbytes); return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
} }
} }
} }
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){ return add_to_cache(v, {1}, starting_multiple_);
int lhs = populate_starting_multiple(x->get_operand(0)); }
int rhs = populate_starting_multiple(x->get_operand(1));
if(x->is_int_mult())
return cache(lhs * rhs); std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(x->is_int_add_sub()) if(starting_multiple_.find(v) != starting_multiple_.end())
return cache(gcd(lhs, rhs)); return starting_multiple_.at(v);
if(x->is_int_div()) if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return cache(std::max(lhs / rhs, 1)); return populate_starting_multiple_binop(x);
if(x->is_int_rem() && rhs > 1) if(auto *x = dynamic_cast<ir::constant_int*>(v))
return cache(gcd(lhs, rhs)); return add_to_cache(x, {(unsigned)x->get_value()}, starting_multiple_);
if(x->is_shl()) if(auto *x = dynamic_cast<ir::constant_range*>(v))
return cache(lhs << rhs); return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
if(x->is_shr()) if(auto *x = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(v))
return cache(std::max(lhs >> rhs, 1)); return add_to_cache(x, {128}, starting_multiple_);
} if(auto *x = dynamic_cast<ir::nv_static_program_idx*>(v))
if(auto *x = dynamic_cast<ir::constant_int*>(v)){ return add_to_cache(x, {(unsigned)x->get_range()->get_first()->get_value()}, starting_multiple_);
return cache(x->get_value()); if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
} return populate_starting_multiple_gep(x);
if(auto *x = dynamic_cast<ir::constant_range*>(v)){ if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return cache(x->get_first()->get_value()); return populate_starting_multiple_splat(x);
} if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
if(dynamic_cast<ir::nv_dynamic_program_idx_inst*>(v)){ return populate_starting_multiple_reshape(x);
return cache(128); if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
} return populate_starting_multiple_broadcast(x);
if(auto *x = dynamic_cast<ir::nv_static_program_idx*>(v)){ if(auto *x = dynamic_cast<ir::phi_node*>(v))
return cache(x->get_range()->get_first()->get_value()); return populate_starting_multiple_phi(x);
} return populate_starting_multiple_default(v);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
int lhs = populate_starting_multiple(x->get_operand(0));
int rhs = populate_starting_multiple(x->get_operand(1));
return cache(gcd(lhs, rhs));
}
if(auto *x = dynamic_cast<ir::splat_inst*>(v)){
int op = populate_starting_multiple(x->get_operand(0));
return cache(op);
}
if(auto *x = dynamic_cast<ir::reshape_inst*>(v)){
int op = populate_starting_multiple(x->get_operand(0));
auto shapes = x->get_type()->get_tile_shapes();
if(shapes[0] == 1)
return cache(1);
else
return cache(op);
}
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v)){
int op = populate_starting_multiple(x->get_operand(0));
return cache(op);
}
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
// put a conservative initial value in phi node to avoid infinite recursion
unsigned result = 1;
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
if(starting_multiple_.find(inc) != starting_multiple_.end())
result = starting_multiple_.at(inc);
}
cache(result);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
result = gcd(result, populate_starting_multiple(inc));
}
return cache(result);
}
// scalars
if(!v->get_type()->is_tile_ty())
return cache(1);
// tiles
auto shapes = v->get_type()->get_tile_shapes();
unsigned result = 1;
for(unsigned i = 0; i < shapes.size() - 1; i++)
result *= shapes[i];
return cache(result);
} }
unsigned align::get_starting_multiple(ir::value* v) const { unsigned align::get_starting_multiple(ir::value* v) const {
return starting_multiple_.at(v); return starting_multiple_.at(v)[0];
} }
unsigned align::get_max_contiguous(ir::value* v) const { unsigned align::get_max_contiguous(ir::value* v) const {
return max_contiguous_.at(v)[0];
}
std::vector<unsigned> align::get_max_contiguous_vec(ir::value* v) const {
return max_contiguous_.at(v); return max_contiguous_.at(v);
} }
@@ -297,7 +489,7 @@ void align::copy(ir::value *dst, ir::value *src) {
is_constant_[dst] = is_constant_[src]; is_constant_[dst] = is_constant_[src];
} }
///TODO: This doesn't seem to work in DOT-NN, DOT-TT, DOT-TN
void align::run(ir::module &mod) { void align::run(ir::module &mod) {
// populate constant // populate constant
for(ir::function *fn: mod.get_function_list()) for(ir::function *fn: mod.get_function_list())
@@ -316,13 +508,9 @@ void align::run(ir::module &mod) {
// populate maximum contiguous // populate maximum contiguous
for(ir::function *fn: mod.get_function_list()) for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks()) for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()) for(ir::instruction *i: block->get_inst_list()){
populate_max_contiguous(i); populate_max_contiguous(i);
}
// for(ir::function *fn: mod.get_function_list())
// for(ir::basic_block *block: fn->blocks())
// for(ir::instruction *i: block->get_inst_list())
// std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << std::endl;
} }

View File

@@ -76,16 +76,16 @@ void grids::init_c_graph(ir::instruction *v) {
// Reshape // Reshape
if(dynamic_cast<ir::reshape_inst*>(v)) { if(dynamic_cast<ir::reshape_inst*>(v)) {
ir::value *op = v->get_operand(0); ir::value *op = v->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
unsigned current = 0; unsigned current = 0;
bool is_skewed = false; bool is_skewed = false;
for(unsigned i = 0; i < shapes.size(); i ++){ for(unsigned i = 0; i < shapes.size(); i ++){
bool is_one = shapes[i] == 1; if(shapes[i] == 1){
bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
if(is_one){
static_params_.insert({{v, i}, 1}); static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {v, i}); add_constraint({v, i}, {v, i});
} }
else if(!is_skewed && is_same) else if(!is_skewed &&
shapes[i] == op_shapes[current])
add_constraint({v, i}, {op, current++}); add_constraint({v, i}, {op, current++});
else{ else{
is_skewed = true; is_skewed = true;
@@ -130,13 +130,10 @@ void grids::init_c_graph(ir::instruction *v) {
} }
// Element-wise // Element-wise
else if(dynamic_cast<ir::user*>(v)) { else if(dynamic_cast<ir::user*>(v)) {
for(unsigned k = 0; k < v->get_num_results(); k++){ for(unsigned i = 0; i < shapes.size(); i ++){
ir::value *result = v->get_result(k); std::vector<ir::value*> ops = v->ops();
for(unsigned i = 0; i < shapes.size(); i ++){ for(ir::value* op: ops)
std::vector<ir::value*> ops = v->ops(); add_constraint({v, i}, {op, i});
for(ir::value* op: ops)
add_constraint({result, i}, {op, i});
}
} }
} }
} }

View File

@@ -864,11 +864,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
std::map<unsigned, ir::value*> references; std::map<unsigned, ir::value*> references;
create_grids(grids, references, fn); create_grids(grids, references, fn);
for(ir::value* i: grids){ for(ir::value* i: grids){
if(auto *instr = dynamic_cast<ir::instruction*>(i)) init_axes(i, builder, u_thread_warp_id, u_warp_id);
for(unsigned r = 0; r < instr->get_num_results(); r++)
init_axes(instr->get_result(r), builder, u_thread_warp_id, u_warp_id);
else
init_axes(i, builder, u_thread_warp_id, u_warp_id);
} }
// create tile // create tile
std::set<ir::value*> seen; std::set<ir::value*> seen;
@@ -876,8 +872,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
for(ir::instruction *i: block->get_inst_list()){ for(ir::instruction *i: block->get_inst_list()){
if(!i->get_type()->is_tile_ty()) if(!i->get_type()->is_tile_ty())
continue; continue;
for(unsigned r = 0; r < i->get_num_results(); r++) create_tile(i, builder, references, seen, sh_mem_ptr);
create_tile(i->get_result(r), builder, references, seen, sh_mem_ptr);
} }
} }

View File

@@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl; std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context); cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code // JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -20,11 +20,6 @@ instruction::instruction(type *ty, unsigned num_ops, unsigned num_results, const
auto it = std::find(block->begin(), block->end(), next); auto it = std::find(block->begin(), block->end(), next);
block->get_inst_list().insert(it, next); block->get_inst_list().insert(it, next);
} }
if(num_results == 1)
results_.push_back(this);
else
for(unsigned i = 0; i < num_results; i++)
results_.push_back(new result_reference(this, i));
} }
void instruction::erase_from_parent() { void instruction::erase_from_parent() {

View File

@@ -48,14 +48,8 @@ void print(module &mod, std::ostream& os) {
os << std::endl; os << std::endl;
for(ir::instruction *inst: block->get_inst_list()){ for(ir::instruction *inst: block->get_inst_list()){
os << " "; os << " ";
unsigned num_results = inst->get_num_results(); os << get_name(inst, cnt++);
for(unsigned i = 0; i < num_results; i++){ os << " = ";
os << get_name(inst->get_result(i), cnt++);
if(i < num_results - 1)
os << ", ";
else
os << " = ";
}
ir::type* type = inst->get_type(); ir::type* type = inst->get_type();
os << inst->repr() << " " << type->repr(); os << inst->repr() << " " << type->repr();
ir::instruction::ops_t ops = inst->ops(); ir::instruction::ops_t ops = inst->ops();

View File

@@ -5,6 +5,7 @@
#include <algorithm> #include <algorithm>
#include "triton/codegen/selection.h" #include "triton/codegen/selection.h"
#include "triton/runtime/function.h" #include "triton/runtime/function.h"
#include "triton/codegen/transform/reorder.h"
#include "triton/lang/cpp.h" #include "triton/lang/cpp.h"
#include "triton/lang/parser.h" #include "triton/lang/parser.h"
#include "triton/lang/code_gen.h" #include "triton/lang/code_gen.h"
@@ -198,6 +199,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::analysis::liveness shmem_liveness(&shmem_info); codegen::analysis::liveness shmem_liveness(&shmem_info);
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids); codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
codegen::analysis::align alignment_info; codegen::analysis::align alignment_info;
codegen::transform::reorder reorder(&alignment_info);
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info); codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
codegen::transform::vectorize vectorize(&grids); codegen::transform::vectorize vectorize(&grids);
codegen::transform::dce dce; codegen::transform::dce dce;
@@ -208,6 +210,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
peephole.run(module); peephole.run(module);
dce.run(module); dce.run(module);
alignment_info.run(module); alignment_info.run(module);
ir::print(module, std::cout);
// reorder.run(module);
dce.run(module);
ir::print(module, std::cout);
grids.run(module); grids.run(module);
reassociate.run(module); reassociate.run(module);
dce.run(module); dce.run(module);

View File

@@ -38,7 +38,7 @@ void copy2d(TYPE * X __noalias __readonly __aligned(16),
int rm[TM] = ridm * TM + 0 ... TM; int rm[TM] = ridm * TM + 0 ... TM;
int rn[TN] = ridn * TN + 0 ... TN; int rn[TN] = ridn * TN + 0 ... TN;
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldy; TYPE* py[TM, TN] = Y + rm[:, newaxis] * ldy + rn[newaxis, :];
*py = *px; *py = *px;
} }
)"; )";