ugh
This commit is contained in:
@@ -2,12 +2,19 @@
|
||||
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace ir {
|
||||
class value;
|
||||
class module;
|
||||
class phi_node;
|
||||
class splat_inst;
|
||||
class reshape_inst;
|
||||
class broadcast_inst;
|
||||
class binary_operator;
|
||||
class getelementptr_inst;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
@@ -22,22 +29,47 @@ class align {
|
||||
private:
|
||||
// helpers
|
||||
bool is_first_axis_unit(ir::value *v);
|
||||
std::vector<unsigned> get_shapes(ir::value *v);
|
||||
|
||||
// populate maps
|
||||
cst_info populate_is_constant(ir::value *v);
|
||||
unsigned populate_max_contiguous(ir::value *v);
|
||||
unsigned populate_starting_multiple(ir::value *v);
|
||||
// populate is_constant
|
||||
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
|
||||
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
|
||||
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
|
||||
std::vector<cst_info> populate_is_constant_default(ir::value* v);
|
||||
std::vector<cst_info> populate_is_constant(ir::value *v);
|
||||
// populate max_contiguous
|
||||
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
|
||||
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
|
||||
std::vector<unsigned> populate_max_contiguous(ir::value *v);
|
||||
// populate starting_multiple
|
||||
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
|
||||
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
|
||||
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
|
||||
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
|
||||
std::vector<unsigned> populate_starting_multiple(ir::value *v);
|
||||
|
||||
public:
|
||||
void run(ir::module &mod);
|
||||
unsigned get_starting_multiple(ir::value* v) const;
|
||||
unsigned get_max_contiguous(ir::value* v) const;
|
||||
std::vector<unsigned> get_max_contiguous_vec(ir::value* v) const;
|
||||
void copy(ir::value *dst, ir::value *src);
|
||||
|
||||
private:
|
||||
std::map<ir::value*, cst_info> is_constant_;
|
||||
std::map<ir::value*, unsigned> max_contiguous_;
|
||||
std::map<ir::value*, unsigned> starting_multiple_;
|
||||
std::map<ir::value*, std::vector<cst_info>> is_constant_;
|
||||
std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
|
||||
std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -11,6 +11,10 @@
|
||||
#include "triton/ir/type.h"
|
||||
#include "triton/ir/metadata.h"
|
||||
|
||||
#define _TRITON_DEFINE_CLONE(name) \
|
||||
ir::instruction* clone_impl() const { return new name(*this); }
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
@@ -25,10 +29,15 @@ class context;
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class result_reference;
|
||||
|
||||
|
||||
class instruction: public user{
|
||||
public:
|
||||
virtual std::string repr_impl() const = 0;
|
||||
|
||||
private:
|
||||
virtual ir::instruction* clone_impl() const = 0;
|
||||
|
||||
protected:
|
||||
// constructors
|
||||
instruction(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -43,19 +52,27 @@ public:
|
||||
bool has_tile_result_or_op();
|
||||
// repr
|
||||
std::string repr() const { return repr_impl(); }
|
||||
// results
|
||||
unsigned get_num_results() const { return results_.size(); }
|
||||
value* get_result(unsigned i) { return results_.at(i); }
|
||||
// metadata
|
||||
void set_metadata(ir::metadata::kind_t kind,
|
||||
unsigned value) { metadatas_[kind] = value;}
|
||||
unsigned get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
|
||||
// cloning
|
||||
// Returns a copy of this instruction produced by the subclass-specific
// clone_impl(). The copy is renamed to "testcloned" and has its parent
// pointer reset to nullptr before being handed back to the caller.
ir::instruction* clone() {
  ir::instruction* copy = clone_impl();
  // The loop that would register the copy as a user of each operand is
  // currently disabled:
  // for(auto it = op_begin(); it != op_end(); it++){
  //   (*it)->add_use(copy);
  // }
  copy->set_name("testcloned");
  copy->parent_ = nullptr;
  return copy;
}
|
||||
|
||||
private:
|
||||
basic_block *parent_;
|
||||
std::vector<value*> results_;
|
||||
std::map<ir::metadata::kind_t, unsigned> metadatas_;
|
||||
};
|
||||
|
||||
|
||||
// result reference
|
||||
class result_reference: public value {
|
||||
public:
|
||||
@@ -72,7 +89,7 @@ private:
|
||||
// phi_node classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class phi_node: public instruction{
|
||||
class phi_node: public instruction {
|
||||
private:
|
||||
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "phi"; }
|
||||
@@ -91,6 +108,8 @@ public:
|
||||
// Factory methods
|
||||
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(phi_node)
|
||||
|
||||
private:
|
||||
unsigned num_reserved_;
|
||||
std::vector<basic_block*> blocks_;
|
||||
@@ -99,7 +118,7 @@ private:
|
||||
//===----------------------------------------------------------------------===//
|
||||
// binary_operator classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
class binary_operator: public instruction{
|
||||
class binary_operator: public instruction {
|
||||
public:
|
||||
typedef binary_op_t op_t;
|
||||
|
||||
@@ -138,6 +157,8 @@ public:
|
||||
static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(binary_operator)
|
||||
|
||||
public:
|
||||
binary_op_t op_;
|
||||
bool has_no_unsigned_wrap_;
|
||||
@@ -168,20 +189,22 @@ private:
|
||||
cmp_pred_t pred_;
|
||||
};
|
||||
|
||||
class icmp_inst: public cmp_inst{
|
||||
class icmp_inst: public cmp_inst {
|
||||
using cmp_inst::cmp_inst;
|
||||
|
||||
public:
|
||||
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(icmp_inst)
|
||||
};
|
||||
|
||||
class fcmp_inst: public cmp_inst{
|
||||
class fcmp_inst: public cmp_inst {
|
||||
using cmp_inst::cmp_inst;
|
||||
|
||||
public:
|
||||
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(fcmp_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -224,7 +247,8 @@ private:
|
||||
};
|
||||
|
||||
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \
|
||||
class name : public cast_inst{ \
|
||||
class name : public cast_inst { \
|
||||
_TRITON_DEFINE_CLONE(name); \
|
||||
friend class cast_inst; \
|
||||
name(type *ty, value *v, const std::string &name, instruction *next) \
|
||||
: cast_inst(ty, v, name, next, op){ } \
|
||||
@@ -253,7 +277,7 @@ class terminator_inst: public instruction{
|
||||
};
|
||||
|
||||
// return instruction
|
||||
class return_inst: public terminator_inst{
|
||||
class return_inst: public terminator_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "ret"; }
|
||||
return_inst(context &ctx, value *ret_val, instruction *next);
|
||||
@@ -267,6 +291,8 @@ public:
|
||||
|
||||
// factory methods
|
||||
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
|
||||
|
||||
_TRITON_DEFINE_CLONE(return_inst)
|
||||
};
|
||||
|
||||
// base branch instruction
|
||||
@@ -294,6 +320,7 @@ public:
|
||||
basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
|
||||
basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
|
||||
value *get_cond() { return get_operand(2); }
|
||||
_TRITON_DEFINE_CLONE(cond_branch_inst)
|
||||
};
|
||||
|
||||
// unconditional branch
|
||||
@@ -304,28 +331,15 @@ private:
|
||||
|
||||
public:
|
||||
basic_block *get_dest() { return (basic_block*)get_operand(0); }
|
||||
_TRITON_DEFINE_CLONE(uncond_branch_inst)
|
||||
};
|
||||
|
||||
// ternary
|
||||
class ternary_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "cond"; }
|
||||
ternary_inst(value *cond, value *true_value, value *false_value,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
value *get_cond() { return get_operand(0); }
|
||||
value *get_true_value() { return get_operand(1); }
|
||||
value *get_false_value() { return get_operand(2); }
|
||||
static ternary_inst* create(value *cond, value *true_value, value *false_value,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class getelementptr_inst: public instruction{
|
||||
class getelementptr_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "getelementptr"; }
|
||||
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
|
||||
@@ -345,6 +359,7 @@ public:
|
||||
// factory methods
|
||||
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(getelementptr_inst)
|
||||
|
||||
private:
|
||||
type *source_elt_ty;
|
||||
@@ -358,12 +373,16 @@ private:
|
||||
class io_inst: public instruction {
|
||||
protected:
|
||||
io_inst(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
|
||||
// value *get_mask() const;
|
||||
// value *get_false_value() const;
|
||||
};
|
||||
|
||||
class load_inst: public io_inst{
|
||||
class load_inst: public io_inst {
|
||||
protected:
|
||||
load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next);
|
||||
|
||||
@@ -372,15 +391,15 @@ private:
|
||||
static type *get_pointee_type(type *ty);
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
|
||||
// factory method
|
||||
static load_inst* create(value *ptr,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(load_inst)
|
||||
};
|
||||
|
||||
class masked_load_inst: public load_inst{
|
||||
class masked_load_inst: public load_inst {
|
||||
private:
|
||||
std::string repr_impl() const { return "masked_load"; }
|
||||
masked_load_inst(value *ptr, value *mask, value *false_value,
|
||||
@@ -394,6 +413,7 @@ public:
|
||||
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_load_inst)
|
||||
};
|
||||
|
||||
class store_inst: public io_inst{
|
||||
@@ -406,12 +426,12 @@ private:
|
||||
|
||||
public:
|
||||
// accessors
|
||||
value *get_pointer_operand() { return get_operand(0); }
|
||||
value *get_value_operand() { return get_operand(1); }
|
||||
// factory method
|
||||
static store_inst* create(value* ptr, value *v,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(store_inst)
|
||||
};
|
||||
|
||||
class masked_store_inst: public store_inst{
|
||||
@@ -427,6 +447,7 @@ public:
|
||||
static masked_store_inst* create(value *ptr, value *v, value *mask,
|
||||
const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(masked_store_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -450,6 +471,7 @@ private:
|
||||
public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(reshape_inst)
|
||||
};
|
||||
|
||||
// splat
|
||||
@@ -462,6 +484,7 @@ private:
|
||||
public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(splat_inst)
|
||||
};
|
||||
|
||||
// broadcast
|
||||
@@ -474,6 +497,7 @@ private:
|
||||
public:
|
||||
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(broadcast_inst)
|
||||
};
|
||||
|
||||
|
||||
@@ -486,6 +510,7 @@ private:
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(downcast_inst)
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -505,6 +530,7 @@ private:
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_program_id_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
@@ -518,6 +544,7 @@ private:
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
unsigned get_axis() const { return axis_; }
|
||||
_TRITON_DEFINE_CLONE(get_num_program_inst)
|
||||
|
||||
private:
|
||||
unsigned axis_;
|
||||
@@ -527,6 +554,7 @@ class atomic_cas_inst: public builtin_inst {
|
||||
private:
|
||||
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "atomic_cas"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_cas_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -536,6 +564,7 @@ class atomic_exch_inst: public builtin_inst {
|
||||
private:
|
||||
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_exch"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_exch_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -545,6 +574,7 @@ class atomic_add_inst: public builtin_inst {
|
||||
private:
|
||||
atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
std::string repr_impl() const { return "atomic_add"; }
|
||||
_TRITON_DEFINE_CLONE(atomic_add_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -566,6 +596,7 @@ public:
|
||||
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
|
||||
bool is_a_trans() { return AT_ == Trans; }
|
||||
bool is_b_trans() { return BT_ == Trans; }
|
||||
_TRITON_DEFINE_CLONE(dot_inst)
|
||||
|
||||
private:
|
||||
TransT AT_;
|
||||
@@ -586,17 +617,12 @@ public:
|
||||
|
||||
private:
|
||||
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const {
|
||||
std::string res = "trans<";
|
||||
//for(ir::constant_int *x: perm_)
|
||||
// res += x->repr() + ",";
|
||||
res[res.size()-1] = '>';
|
||||
return res;
|
||||
}
|
||||
std::string repr_impl() const { return "trans"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
|
||||
const std::vector<constant_int*> get_perm() const;
|
||||
_TRITON_DEFINE_CLONE(trans_inst)
|
||||
|
||||
private:
|
||||
std::vector<constant_int*> perm_;
|
||||
@@ -608,6 +634,7 @@ private:
|
||||
std::string repr_impl() const { return "sqrt"; }
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(sqrt_inst)
|
||||
};
|
||||
|
||||
class reduce_inst: public builtin_inst {
|
||||
@@ -617,6 +644,7 @@ private:
|
||||
private:
|
||||
reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "reduce"; }
|
||||
_TRITON_DEFINE_CLONE(reduce_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -630,6 +658,7 @@ class select_inst: public builtin_inst {
|
||||
private:
|
||||
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "select"; }
|
||||
_TRITON_DEFINE_CLONE(select_inst)
|
||||
|
||||
public:
|
||||
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -647,12 +676,14 @@ private:
|
||||
public:
|
||||
static copy_to_shared_inst* create(value *arg, const std::string &name = "",
|
||||
instruction *next = nullptr);
|
||||
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
|
||||
};
|
||||
|
||||
class barrier_inst: public instruction{
|
||||
private:
|
||||
barrier_inst(context &ctx, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "barrier"; }
|
||||
_TRITON_DEFINE_CLONE(barrier_inst)
|
||||
|
||||
public:
|
||||
static barrier_inst* create(context &ctx, const std::string &name = "",
|
||||
@@ -663,6 +694,7 @@ class vectorize_inst: public unary_inst{
|
||||
private:
|
||||
using unary_inst::unary_inst;
|
||||
std::string repr_impl() const { return "vectorize"; }
|
||||
_TRITON_DEFINE_CLONE(vectorize_inst)
|
||||
|
||||
public:
|
||||
static vectorize_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -675,6 +707,7 @@ class nv_dynamic_program_idx_inst: public instruction {
|
||||
private:
|
||||
nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "nv_dynamic_program_idx"; }
|
||||
_TRITON_DEFINE_CLONE(nv_dynamic_program_idx_inst)
|
||||
|
||||
public:
|
||||
static nv_dynamic_program_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
||||
|
@@ -5,6 +5,8 @@
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
|
||||
namespace triton {
|
||||
namespace codegen{
|
||||
@@ -36,258 +38,448 @@ bool align::is_first_axis_unit(ir::value *x){
|
||||
return true;
|
||||
}
|
||||
|
||||
align::cst_info align::populate_is_constant(ir::value *v) {
|
||||
/*
|
||||
* is constant
|
||||
*/
|
||||
|
||||
// Returns the per-axis extents of `v`: the tile shapes when the type of `v`
// is a tile type, and the one-element list {1} otherwise.
std::vector<unsigned> align::get_shapes(ir::value *v) {
  ir::type *value_ty = v->get_type();
  if(!value_ty->is_tile_ty())
    return {1};
  return value_ty->get_tile_shapes();
}
|
||||
|
||||
// Computes the per-axis constancy information of a phi node.
//
// A provisional entry is stored for `x` before its incoming values are
// visited. populate_is_constant() returns the cached entry when one exists,
// so a phi that is reachable from one of its own incoming values finds this
// provisional entry instead of recursing without end.
//
// The provisional entry is {1, 0} on every axis, or the entry of an incoming
// value that is already cached (the last such value wins). The final entry
// keeps, for every axis, the smallest `num_cst` over all incoming values.
//
// NOTE(review): `value` is carried over from the provisional entry and is not
// merged across incoming values -- confirm this is what callers expect.
std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
  auto shapes = get_shapes(x);
  std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::value* inc = x->get_incoming_value(n);
    auto it = is_constant_.find(inc);
    if(it != is_constant_.end())
      result = it->second;
  }
  // Seed the cache only; returning here would leave the merge below
  // unreachable.
  add_to_cache(x, result, is_constant_);
  // recurse
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::value* inc = x->get_incoming_value(n);
    auto cst = populate_is_constant(inc);
    // Never index past either vector if the ranks happen to differ.
    const size_t rank = std::min(result.size(), cst.size());
    for(size_t d = 0; d < rank; d++)
      result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst);
  }
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Constancy information for a splat: every axis is reported with `num_cst`
// equal to its full extent, and with the `value` recorded for the first
// entry of the splatted operand.
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
  const std::vector<unsigned> extents = get_shapes(x);
  ir::value* scalar = x->get_operand(0);
  const auto scalar_cst = populate_is_constant(scalar);
  std::vector<cst_info> result;
  for(size_t d = 0; d < extents.size(); d++)
    result.push_back(cst_info{extents[d], scalar_cst[0].value});
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Computes the per-axis constancy information of a reshape.
//
// The result axes are walked in order while `current` tracks the next
// operand axis that has not been matched yet:
//   - a unit result axis gets {1, value of the current operand axis};
//   - a result axis whose extent equals the current operand axis consumes
//     that axis and gets {extent, value of that operand axis};
//   - any other axis marks the reshape as skewed; it and every later
//     non-unit axis get {extent, 0}.
//
// Once every operand axis has been consumed (e.g. a trailing unit axis in
// [N] -> [N, 1]) there is no operand entry left to read, so the value is
// reported as 0 rather than indexing past the end of the operand vectors.
std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value *op = x->get_operand(0);
  auto op_shapes = op->get_type()->get_tile_shapes();
  auto op_cst = populate_is_constant(op);
  unsigned current = 0;
  bool is_skewed = false;
  for(size_t d = 0; d < x_shapes.size(); d++){
    // True while `current` still names a valid operand axis.
    const bool in_range = current < op_shapes.size() && current < op_cst.size();
    cst_info ax;
    if(x_shapes[d] == 1)
      ax = {1, in_range ? op_cst[current].value : 0u};
    else if(!is_skewed && in_range
            && x_shapes[d] == op_shapes[current])
      ax = {x_shapes[d], op_cst[current++].value};
    else {
      is_skewed = true;
      ax = {x_shapes[d], 0};
    }
    result.push_back(ax);
  }
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Constancy information for a broadcast: an axis whose operand extent is 1
// is reported with `num_cst` equal to the result extent and the operand's
// `value`; every other axis reuses the operand's entry unchanged.
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
  auto out_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value *src = x->get_operand(0);
  auto src_shapes = src->get_type()->get_tile_shapes();
  auto src_cst = populate_is_constant(src);
  for(size_t d = 0; d < out_shapes.size(); d++){
    const bool is_stretched = (src_shapes[d] == 1);
    if(is_stretched)
      result.push_back(cst_info{out_shapes[d], src_cst[d].value});
    else
      result.push_back(src_cst[d]);
  }
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Constancy information for a binary operator, computed axis by axis.
//
// For an integer division whose left operand has `num_cst == 0` and whose
// right operand has a non-zero recorded `value`, `num_cst` is the gcd of the
// left operand's max contiguity and that value (the original author flags
// this as possibly not entirely accurate). In every other case `num_cst` is
// the smaller of the two operands' `num_cst`. The resulting `value` is
// always 0.
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value* lhs_op = x->get_operand(0);
  ir::value* rhs_op = x->get_operand(1);
  auto lhs = populate_is_constant(lhs_op);
  auto rhs = populate_is_constant(rhs_op);
  auto lhs_contiguous = populate_max_contiguous(lhs_op);
  for(size_t d = 0; d < x_shapes.size(); d++) {
    unsigned num_cst;
    if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
      // todo might not be entirely true
      num_cst = gcd(lhs_contiguous[d], rhs[d].value);
    }
    else {
      num_cst = std::min(lhs[d].num_cst, rhs[d].num_cst);
    }
    cst_info ax = {num_cst, 0};
    result.push_back(ax);
  }
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Constancy information for a getelementptr: on each axis `num_cst` is the
// smaller of the `num_cst` of operand 0 and operand 1, and `value` is 0.
std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
  auto x_shapes = get_shapes(x);
  ir::value* ptr_op = x->get_operand(0);
  ir::value* idx_op = x->get_operand(1);
  auto ptr_cst = populate_is_constant(ptr_op);
  auto idx_cst = populate_is_constant(idx_op);
  std::vector<cst_info> result;
  for(size_t d = 0; d < x_shapes.size(); d++){
    const unsigned shared_cst = std::min(ptr_cst[d].num_cst, idx_cst[d].num_cst);
    result.push_back({shared_cst, 0});
  }
  return add_to_cache(x, result, is_constant_);
}
|
||||
|
||||
// Fallback constancy information for values with no dedicated handler:
// every axis is reported as {1, 0}.
std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
  const size_t rank = get_shapes(v).size();
  std::vector<cst_info> result(rank, cst_info{1, 0});
  return add_to_cache(v, result, is_constant_);
}
|
||||
|
||||
std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
|
||||
if(is_constant_.find(v) != is_constant_.end())
|
||||
return is_constant_.at(v);
|
||||
// helper for the cache
|
||||
auto cache = [this,v](cst_info value){
|
||||
return add_to_cache(v, value, is_constant_); }
|
||||
;
|
||||
// populate
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_cst = populate_is_constant(op);
|
||||
if(is_first_axis_unit(op)){
|
||||
unsigned num_cst = x->get_type()->get_tile_shapes()[0];
|
||||
return cache({num_cst, op_cst.value});
|
||||
}
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v))
|
||||
return cache({true, (unsigned)x->get_value()});
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||
ir::value* lhs_op = x->get_operand(0);
|
||||
ir::value* rhs_op = x->get_operand(1);
|
||||
cst_info lhs = populate_is_constant(lhs_op);
|
||||
cst_info rhs = populate_is_constant(rhs_op);
|
||||
if(lhs.num_cst==0 && rhs.value && x->is_int_div()){
|
||||
unsigned max_contiguous = populate_max_contiguous(lhs_op);
|
||||
// todo might not be entirely true
|
||||
unsigned num_constants = gcd(max_contiguous, rhs.value);
|
||||
return cache({num_constants, 0});
|
||||
}
|
||||
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||
ir::value* lhs_op = x->get_operand(0);
|
||||
ir::value* rhs_op = x->get_operand(1);
|
||||
cst_info lhs = populate_is_constant(lhs_op);
|
||||
cst_info rhs = populate_is_constant(rhs_op);
|
||||
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
|
||||
}
|
||||
// if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
// cst_info value_true = populate_is_constant(x->get_value_true());
|
||||
// cst_info value_false = populate_is_constant(x->get_value_false());
|
||||
// return cache({std::min(value_true.num_cst, value_false.num_cst), 0});
|
||||
// }
|
||||
if(v->get_type()->is_tile_ty())
|
||||
return cache({0, 0});
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
// put a conservative initial value in phi node to avoid infinite recursion
|
||||
unsigned result = 1;
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
if(is_constant_.find(inc) != is_constant_.end())
|
||||
result = is_constant_.at(inc).num_cst;
|
||||
}
|
||||
cache({result, 0});
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
result = std::min(result, populate_is_constant(inc).num_cst);
|
||||
}
|
||||
return cache({result, 0});
|
||||
}
|
||||
// scalars are always constant in the contiguous dimension
|
||||
// but value is not known at compile-time
|
||||
return cache({1, 0});
|
||||
return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_);
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v))
|
||||
return populate_is_constant_phi(x);
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
|
||||
return populate_is_constant_splat(x);
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
|
||||
return populate_is_constant_reshape(x);
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
|
||||
return populate_is_constant_broadcast(x);
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
|
||||
return populate_is_constant_binop(x);
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
|
||||
return populate_is_constant_gep(x);
|
||||
return populate_is_constant_default(v);
|
||||
}
|
||||
|
||||
unsigned align::populate_max_contiguous(ir::value *v){
|
||||
if(max_contiguous_.find(v) != max_contiguous_.end())
|
||||
return max_contiguous_.at(v);
|
||||
// helper for the cache
|
||||
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
|
||||
// populate
|
||||
if(!v->get_type()->is_tile_ty())
|
||||
return cache(1);
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
if(dynamic_cast<ir::constant_range*>(v)){
|
||||
return cache(shapes[0]);
|
||||
|
||||
/*
|
||||
* max contiguous
|
||||
*/
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result(shapes.size(), 1);
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
auto it = max_contiguous_.find(inc);
|
||||
if(it != max_contiguous_.end())
|
||||
result = it->second;
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
|
||||
ir::value *op = x->get_operand(0);
|
||||
if(op->get_type()->is_tile_ty()){
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
if(op_shapes[0] == shapes[0])
|
||||
return cache(populate_max_contiguous(op));
|
||||
add_to_cache(x, result, max_contiguous_);
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
auto contiguous = populate_max_contiguous(inc);
|
||||
for(size_t d = 0; d < result.size(); d++)
|
||||
result[d] = std::min(result[d], contiguous[d]);
|
||||
}
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
|
||||
}
|
||||
|
||||
// Max-contiguity information for a splat: 1 on every axis of the result.
std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
  const size_t rank = get_shapes(x).size();
  std::vector<unsigned> result(rank, 1u);
  return add_to_cache(x, result, max_contiguous_);
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
std::vector<unsigned> result;
|
||||
ir::value *op = x->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
auto op_mc = populate_max_contiguous(op);
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
for(size_t d = 0; d < shapes.size(); d ++){
|
||||
if(shapes[d] == 1)
|
||||
result.push_back(1);
|
||||
else if(!is_skewed
|
||||
&& shapes[d] == op_shapes[current])
|
||||
result.push_back(op_mc[current++]);
|
||||
else {
|
||||
is_skewed = true;
|
||||
result.push_back(1);
|
||||
}
|
||||
return cache(1);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||
cst_info lhs_cst_info = populate_is_constant(lhs);
|
||||
cst_info rhs_cst_info = populate_is_constant(rhs);
|
||||
if(x->is_int_rem() && rhs_cst_info.value > 0)
|
||||
return cache(std::min(lhs_max_contiguous, rhs_cst_info.value));
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
// Max-contiguity information for a broadcast: an axis whose operand extent
// is 1 is reported as 1; every other axis reuses the operand's value.
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
  auto out_shapes = get_shapes(x);
  std::vector<unsigned> result;
  ir::value *src = x->get_operand(0);
  auto src_shapes = src->get_type()->get_tile_shapes();
  auto src_contiguous = populate_max_contiguous(src);
  for(size_t d = 0; d < out_shapes.size(); d++){
    const bool is_stretched = (src_shapes[d] == 1);
    if(is_stretched)
      result.push_back(1);
    else
      result.push_back(src_contiguous[d]);
  }
  return add_to_cache(x, result, max_contiguous_);
}
|
||||
|
||||
std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
|
||||
auto shapes = get_shapes(x);
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
auto lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||
auto rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||
auto lhs_cst_info = populate_is_constant(lhs);
|
||||
auto rhs_cst_info = populate_is_constant(rhs);
|
||||
std::vector<unsigned> result;
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
unsigned value = 1;
|
||||
if(x->is_int_rem() && rhs_cst_info[d].value > 0)
|
||||
value = std::min(lhs_max_contiguous[d], rhs_cst_info[d].value);
|
||||
if(x->is_int_mult()){
|
||||
if(rhs_cst_info.value == 1)
|
||||
return cache(lhs_max_contiguous);
|
||||
if(lhs_cst_info.value == 1)
|
||||
return cache(rhs_max_contiguous);
|
||||
unsigned lvalue = 1, rvalue = 1;
|
||||
if(rhs_cst_info[d].value == 1)
|
||||
lvalue = lhs_max_contiguous[d];
|
||||
if(lhs_cst_info[d].value == 1)
|
||||
rvalue = rhs_max_contiguous[d];
|
||||
value = std::max(lvalue, rvalue);
|
||||
}
|
||||
if(x->is_int_add_sub()){
|
||||
if(lhs_cst_info.num_cst)
|
||||
return cache(gcd(rhs_max_contiguous, lhs_cst_info.num_cst));
|
||||
if(rhs_cst_info.num_cst)
|
||||
return cache(gcd(lhs_max_contiguous, rhs_cst_info.num_cst));
|
||||
unsigned lvalue = 1, rvalue = 1;
|
||||
if(lhs_cst_info[d].num_cst)
|
||||
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
|
||||
if(rhs_cst_info[d].num_cst)
|
||||
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
|
||||
value = std::max(lvalue, rvalue);
|
||||
}
|
||||
result.push_back(value);
|
||||
}
|
||||
// if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
|
||||
// int value_true = populate_max_contiguous(x->get_value_true());
|
||||
// int value_false = populate_max_contiguous(x->get_value_false());
|
||||
// return cache(std::min(value_true, value_false));
|
||||
// }
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||
ir::value* lhs = x->get_operand(0);
|
||||
ir::value* rhs = x->get_operand(1);
|
||||
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
|
||||
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
|
||||
auto lhs_cst_info = populate_is_constant(lhs);
|
||||
auto rhs_cst_info = populate_is_constant(rhs);
|
||||
if(lhs_cst_info.num_cst)
|
||||
return cache(rhs_max_contiguous);
|
||||
if(rhs_cst_info.num_cst)
|
||||
return cache(lhs_max_contiguous);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
// put a conservative initial value in phi node to avoid infinite recursion
|
||||
unsigned result = 1;
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
if(max_contiguous_.find(inc) != max_contiguous_.end())
|
||||
result = max_contiguous_.at(inc);
|
||||
}
|
||||
cache(result);
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
result = std::min(result, populate_max_contiguous(inc));
|
||||
}
|
||||
return cache(result);
|
||||
}
|
||||
return cache(1);
|
||||
return add_to_cache(x, result, max_contiguous_);
|
||||
}
|
||||
|
||||
unsigned align::populate_starting_multiple(ir::value *v){
|
||||
if(starting_multiple_.find(v) != starting_multiple_.end())
|
||||
return starting_multiple_.at(v);
|
||||
auto cache = [this,v](unsigned value){
|
||||
return add_to_cache(v, value, starting_multiple_);
|
||||
};
|
||||
// has metadata
|
||||
// Max-contiguous info for a getelementptr, computed per axis from its two
// operands. When one operand has a non-zero num_cst on an axis, the other
// operand's max-contiguous value becomes a candidate for that axis; the
// larger candidate is kept, and the value stays 1 when neither applies.
std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
  const size_t rank = get_shapes(x).size();
  ir::value* op0 = x->get_operand(0);
  ir::value* op1 = x->get_operand(1);
  auto op0_contiguity = populate_max_contiguous(op0);
  auto op1_contiguity = populate_max_contiguous(op1);
  auto op0_cst = populate_is_constant(op0);
  auto op1_cst = populate_is_constant(op1);
  std::vector<unsigned> contiguity(rank, 1);
  for(size_t axis = 0; axis < rank; axis++){
    const unsigned from_op1 = op0_cst[axis].num_cst ? op1_contiguity[axis] : 1;
    const unsigned from_op0 = op1_cst[axis].num_cst ? op0_contiguity[axis] : 1;
    contiguity[axis] = std::max(from_op1, from_op0);
  }
  return add_to_cache(x, contiguity, max_contiguous_);
}
|
||||
|
||||
// Fallback max-contiguous info for values without a dedicated handler:
//  - non-tile values get the single entry {1};
//  - a constant_range gets a single entry equal to its first tile extent;
//  - any other tile gets 1 on every axis.
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
  auto *ty = v->get_type();
  if(ty->is_tile_ty()){
    auto tile_shapes = ty->get_tile_shapes();
    const bool is_range = dynamic_cast<ir::constant_range*>(v) != nullptr;
    if(is_range)
      return add_to_cache(v, {tile_shapes[0]}, max_contiguous_);
    return add_to_cache(v, std::vector<unsigned>(tile_shapes.size(), 1), max_contiguous_);
  }
  return add_to_cache(v, {1}, max_contiguous_);
}
|
||||
|
||||
// Entry point of the max-contiguous analysis for one value. Returns the
// vector already stored in max_contiguous_ when there is one; otherwise
// forwards to the handler matching v's dynamic type (tested in the order
// below) and falls back to the default handler.
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
  auto cached = max_contiguous_.find(v);
  if(cached != max_contiguous_.end())
    return cached->second;
  if(auto *splat = dynamic_cast<ir::splat_inst*>(v)){
    return populate_max_contiguous_splat(splat);
  }
  if(auto *reshape = dynamic_cast<ir::reshape_inst*>(v)){
    return populate_max_contiguous_reshape(reshape);
  }
  if(auto *broadcast = dynamic_cast<ir::broadcast_inst*>(v)){
    return populate_max_contiguous_broadcast(broadcast);
  }
  if(auto *binop = dynamic_cast<ir::binary_operator*>(v)){
    return populate_max_contiguous_binop(binop);
  }
  if(auto *gep = dynamic_cast<ir::getelementptr_inst*>(v)){
    return populate_max_contiguous_gep(gep);
  }
  if(auto *phi = dynamic_cast<ir::phi_node*>(v)){
    return populate_max_contiguous_phi(phi);
  }
  return populate_max_contiguous_default(v);
}
|
||||
|
||||
|
||||
/*
|
||||
* starting multiple
|
||||
*/
|
||||
|
||||
// Starting-multiple info for a splat: every axis reported by get_shapes(x)
// receives entry 0 of the operand's starting-multiple vector.
std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
  const size_t rank = get_shapes(x).size();
  auto src_multiples = populate_starting_multiple(x->get_operand(0));
  std::vector<unsigned> multiples(rank, src_multiples[0]);
  return add_to_cache(x, multiples, starting_multiple_);
}
|
||||
|
||||
// Starting-multiple info for a reshape. Result axes are walked in order next
// to a cursor into the operand's axes:
//  - a result axis of extent 1 gets 1 and does not advance the cursor;
//  - while the shapes still line up, a result axis whose extent equals the
//    operand extent under the cursor takes that operand axis' value;
//  - from the first mismatch on, every remaining axis gets 1.
std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
  ir::value *src = x->get_operand(0);
  auto src_multiples = populate_starting_multiple(src);
  auto src_shapes = get_shapes(src);
  auto dst_shapes = get_shapes(x);
  // every axis starts at the conservative value 1
  std::vector<unsigned> multiples(dst_shapes.size(), 1);
  unsigned cursor = 0;
  bool mismatched = false;
  for(size_t axis = 0; axis < dst_shapes.size(); axis++){
    if(dst_shapes[axis] == 1)
      continue;
    // short-circuit keeps src_shapes untouched once a mismatch was seen
    if(mismatched || dst_shapes[axis] != src_shapes[cursor]){
      mismatched = true;
      continue;
    }
    multiples[axis] = src_multiples[cursor++];
  }
  return add_to_cache(x, multiples, starting_multiple_);
}
|
||||
|
||||
// Starting-multiple info for a broadcast: the operand's vector is stored
// unchanged under the broadcast instruction.
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
  ir::value *src = x->get_operand(0);
  std::vector<unsigned> multiples = populate_starting_multiple(src);
  return add_to_cache(x, multiples, starting_multiple_);
}
|
||||
|
||||
// Starting-multiple info for a binary operator, combined per axis from the
// operands' starting multiples:
//   mult     -> lhs * rhs
//   add/sub  -> gcd(lhs, rhs)
//   div      -> max(lhs / rhs, 1)
//   rem      -> gcd(lhs, rhs), only when rhs > 1
//   shl      -> lhs << rhs
//   shr      -> max(lhs >> rhs, 1)
// An axis that no rule applies to keeps the conservative value 1.
std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
  auto lhs = populate_starting_multiple(x->get_operand(0));
  auto rhs = populate_starting_multiple(x->get_operand(1));
  std::vector<unsigned> result(lhs.size(), 1);
  // the operator kind does not change while iterating, so query it once
  const bool is_mult = x->is_int_mult();
  const bool is_add_sub = x->is_int_add_sub();
  const bool is_div = x->is_int_div();
  const bool is_rem = x->is_int_rem();
  const bool is_shl = x->is_shl();
  const bool is_shr = x->is_shr();
  // never read rhs past its end if the two operands report different ranks;
  // axes beyond the common rank keep the value 1
  const size_t rank = std::min(lhs.size(), rhs.size());
  // shifting an unsigned by its bit width or more is undefined behavior
  const unsigned num_bits = 8 * sizeof(unsigned);
  for(size_t d = 0; d < rank; d++){
    if(is_mult)
      result[d] = lhs[d] * rhs[d];
    if(is_add_sub)
      result[d] = gcd(lhs[d], rhs[d]);
    // a starting multiple can be 0 (e.g. the constant 0); dividing by it is
    // undefined, so such an axis keeps the value 1
    if(is_div && rhs[d] != 0)
      result[d] = std::max<unsigned>(lhs[d] / rhs[d], 1);
    if(is_rem && rhs[d] > 1)
      result[d] = gcd(lhs[d], rhs[d]);
    // an oversized left shift keeps the value 1
    if(is_shl && rhs[d] < num_bits)
      result[d] = lhs[d] << rhs[d];
    // an oversized right shift would leave no bits, which clamps to 1
    if(is_shr)
      result[d] = (rhs[d] < num_bits) ? std::max<unsigned>(lhs[d] >> rhs[d], 1) : 1;
  }
  return add_to_cache(x, result, starting_multiple_);
}
|
||||
|
||||
// Starting-multiple info for a getelementptr: on each axis, the gcd of the
// two operands' starting multiples.
std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
  ir::value *op0 = x->get_operand(0);
  ir::value *op1 = x->get_operand(1);
  auto op0_multiples = populate_starting_multiple(op0);
  auto op1_multiples = populate_starting_multiple(op1);
  const size_t rank = op0_multiples.size();
  std::vector<unsigned> multiples(rank, 1);
  size_t axis = 0;
  while(axis < rank){
    multiples[axis] = gcd(op0_multiples[axis], op1_multiples[axis]);
    ++axis;
  }
  return add_to_cache(x, multiples, starting_multiple_);
}
|
||||
|
||||
// Starting-multiple info for a phi node.
// Step 1: build a provisional vector — all ones, replaced by the vector of
//         the last incoming value that already has an entry, if any.
// Step 2: store that provisional vector under x before recursing, so that a
//         lookup of x during the recursion finds an entry.
// Step 3: analyze every incoming value and fold it in with a per-axis gcd,
//         then store and return the final vector.
std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
  std::vector<unsigned> multiples(get_shapes(x).size(), 1);
  for(unsigned k = 0; k < x->get_num_incoming(); k++){
    auto known = starting_multiple_.find(x->get_incoming_value(k));
    if(known != starting_multiple_.end())
      multiples = known->second;
  }
  add_to_cache(x, multiples, starting_multiple_);
  // recurse into the incoming values
  for(unsigned k = 0; k < x->get_num_incoming(); k++){
    ir::value* incoming = x->get_incoming_value(k);
    auto incoming_multiples = populate_starting_multiple(incoming);
    for(size_t axis = 0; axis < multiples.size(); axis++)
      multiples[axis] = gcd(multiples[axis], incoming_multiples[axis]);
  }
  return add_to_cache(x, multiples, starting_multiple_);
}
|
||||
|
||||
|
||||
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
|
||||
ir::type* ty = v->get_type();
|
||||
if(ty->is_tile_ty()) {
|
||||
return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(v)){
|
||||
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
|
||||
if(multiple_of > 0)
|
||||
return cache(multiple_of);
|
||||
return add_to_cache(x, {multiple_of}, starting_multiple_);
|
||||
}
|
||||
// arguments
|
||||
if(auto *x = dynamic_cast<ir::argument*>(v)){
|
||||
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
|
||||
for(auto attr: attributes){
|
||||
if(attr.get_kind() == ir::multiple_of)
|
||||
return cache(attr.get_value());
|
||||
if(attr.get_kind() == ir::multiple_of){
|
||||
return add_to_cache(x, {attr.get_value()}, starting_multiple_);
|
||||
}
|
||||
if(attr.get_kind() == ir::aligned){
|
||||
ir::type* ty = x->get_type()->get_pointer_element_ty();
|
||||
int nbits = ty->get_primitive_size_in_bits();
|
||||
int nbytes = nbits / 8;
|
||||
return cache(attr.get_value() / nbytes);
|
||||
return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
|
||||
}
|
||||
}
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
|
||||
int lhs = populate_starting_multiple(x->get_operand(0));
|
||||
int rhs = populate_starting_multiple(x->get_operand(1));
|
||||
if(x->is_int_mult())
|
||||
return cache(lhs * rhs);
|
||||
if(x->is_int_add_sub())
|
||||
return cache(gcd(lhs, rhs));
|
||||
if(x->is_int_div())
|
||||
return cache(std::max(lhs / rhs, 1));
|
||||
if(x->is_int_rem() && rhs > 1)
|
||||
return cache(gcd(lhs, rhs));
|
||||
if(x->is_shl())
|
||||
return cache(lhs << rhs);
|
||||
if(x->is_shr())
|
||||
return cache(std::max(lhs >> rhs, 1));
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::constant_int*>(v)){
|
||||
return cache(x->get_value());
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::constant_range*>(v)){
|
||||
return cache(x->get_first()->get_value());
|
||||
}
|
||||
if(dynamic_cast<ir::nv_dynamic_program_idx_inst*>(v)){
|
||||
return cache(128);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::nv_static_program_idx*>(v)){
|
||||
return cache(x->get_range()->get_first()->get_value());
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||
int lhs = populate_starting_multiple(x->get_operand(0));
|
||||
int rhs = populate_starting_multiple(x->get_operand(1));
|
||||
return cache(gcd(lhs, rhs));
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::splat_inst*>(v)){
|
||||
int op = populate_starting_multiple(x->get_operand(0));
|
||||
return cache(op);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::reshape_inst*>(v)){
|
||||
int op = populate_starting_multiple(x->get_operand(0));
|
||||
auto shapes = x->get_type()->get_tile_shapes();
|
||||
if(shapes[0] == 1)
|
||||
return cache(1);
|
||||
else
|
||||
return cache(op);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v)){
|
||||
int op = populate_starting_multiple(x->get_operand(0));
|
||||
return cache(op);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
|
||||
// put a conservative initial value in phi node to avoid infinite recursion
|
||||
unsigned result = 1;
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
if(starting_multiple_.find(inc) != starting_multiple_.end())
|
||||
result = starting_multiple_.at(inc);
|
||||
}
|
||||
cache(result);
|
||||
// recurse
|
||||
for(unsigned n = 0; n < x->get_num_incoming(); n++){
|
||||
ir::value* inc = x->get_incoming_value(n);
|
||||
result = gcd(result, populate_starting_multiple(inc));
|
||||
}
|
||||
return cache(result);
|
||||
}
|
||||
// scalars
|
||||
if(!v->get_type()->is_tile_ty())
|
||||
return cache(1);
|
||||
// tiles
|
||||
auto shapes = v->get_type()->get_tile_shapes();
|
||||
unsigned result = 1;
|
||||
for(unsigned i = 0; i < shapes.size() - 1; i++)
|
||||
result *= shapes[i];
|
||||
return cache(result);
|
||||
return add_to_cache(v, {1}, starting_multiple_);
|
||||
}
|
||||
|
||||
|
||||
// Entry point of the starting-multiple analysis for one value. Returns the
// vector already stored in starting_multiple_ when there is one; otherwise
// dispatches on v's dynamic type, in the order below. Constants and program
// indices are stored directly as single-entry vectors; everything else goes
// to a dedicated handler, with the default handler as the fallback.
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
  auto cached = starting_multiple_.find(v);
  if(cached != starting_multiple_.end())
    return cached->second;
  if(auto *binop = dynamic_cast<ir::binary_operator*>(v)){
    return populate_starting_multiple_binop(binop);
  }
  if(auto *cst = dynamic_cast<ir::constant_int*>(v)){
    return add_to_cache(cst, {static_cast<unsigned>(cst->get_value())}, starting_multiple_);
  }
  if(auto *range = dynamic_cast<ir::constant_range*>(v)){
    return add_to_cache(range, {static_cast<unsigned>(range->get_first()->get_value())}, starting_multiple_);
  }
  if(auto *dyn_idx = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(v)){
    return add_to_cache(dyn_idx, {128}, starting_multiple_);
  }
  if(auto *static_idx = dynamic_cast<ir::nv_static_program_idx*>(v)){
    return add_to_cache(static_idx, {static_cast<unsigned>(static_idx->get_range()->get_first()->get_value())}, starting_multiple_);
  }
  if(auto *gep = dynamic_cast<ir::getelementptr_inst*>(v)){
    return populate_starting_multiple_gep(gep);
  }
  if(auto *splat = dynamic_cast<ir::splat_inst*>(v)){
    return populate_starting_multiple_splat(splat);
  }
  if(auto *reshape = dynamic_cast<ir::reshape_inst*>(v)){
    return populate_starting_multiple_reshape(reshape);
  }
  if(auto *broadcast = dynamic_cast<ir::broadcast_inst*>(v)){
    return populate_starting_multiple_broadcast(broadcast);
  }
  if(auto *phi = dynamic_cast<ir::phi_node*>(v)){
    return populate_starting_multiple_phi(phi);
  }
  return populate_starting_multiple_default(v);
}
|
||||
|
||||
unsigned align::get_starting_multiple(ir::value* v) const {
|
||||
return starting_multiple_.at(v);
|
||||
return starting_multiple_.at(v)[0];
|
||||
}
|
||||
|
||||
// Returns entry 0 of the max-contiguous vector stored for v.
// The lookup uses .at(), so v must already have an entry in max_contiguous_.
unsigned align::get_max_contiguous(ir::value* v) const {
  const auto& contiguity = max_contiguous_.at(v);
  return contiguity[0];
}
|
||||
|
||||
// Returns a copy of the whole max-contiguous vector stored for v.
// The lookup uses .at(), so v must already have an entry in max_contiguous_.
std::vector<unsigned> align::get_max_contiguous_vec(ir::value* v) const {
  const auto& contiguity = max_contiguous_.at(v);
  return contiguity;
}
|
||||
|
||||
@@ -297,7 +489,7 @@ void align::copy(ir::value *dst, ir::value *src) {
|
||||
is_constant_[dst] = is_constant_[src];
|
||||
}
|
||||
|
||||
///TODO: This doesn't seem to work in DOT-NN, DOT-TT, DOT-TN
|
||||
|
||||
void align::run(ir::module &mod) {
|
||||
// populate constant
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
@@ -316,13 +508,9 @@ void align::run(ir::module &mod) {
|
||||
// populate maximum contiguous
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
for(ir::instruction *i: block->get_inst_list())
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
populate_max_contiguous(i);
|
||||
|
||||
// for(ir::function *fn: mod.get_function_list())
|
||||
// for(ir::basic_block *block: fn->blocks())
|
||||
// for(ir::instruction *i: block->get_inst_list())
|
||||
// std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@@ -76,16 +76,16 @@ void grids::init_c_graph(ir::instruction *v) {
|
||||
// Reshape
|
||||
if(dynamic_cast<ir::reshape_inst*>(v)) {
|
||||
ir::value *op = v->get_operand(0);
|
||||
auto op_shapes = op->get_type()->get_tile_shapes();
|
||||
unsigned current = 0;
|
||||
bool is_skewed = false;
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
bool is_one = shapes[i] == 1;
|
||||
bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
|
||||
if(is_one){
|
||||
if(shapes[i] == 1){
|
||||
static_params_.insert({{v, i}, 1});
|
||||
add_constraint({v, i}, {v, i});
|
||||
}
|
||||
else if(!is_skewed && is_same)
|
||||
else if(!is_skewed &&
|
||||
shapes[i] == op_shapes[current])
|
||||
add_constraint({v, i}, {op, current++});
|
||||
else{
|
||||
is_skewed = true;
|
||||
@@ -130,13 +130,10 @@ void grids::init_c_graph(ir::instruction *v) {
|
||||
}
|
||||
// Element-wise
|
||||
else if(dynamic_cast<ir::user*>(v)) {
|
||||
for(unsigned k = 0; k < v->get_num_results(); k++){
|
||||
ir::value *result = v->get_result(k);
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
std::vector<ir::value*> ops = v->ops();
|
||||
for(ir::value* op: ops)
|
||||
add_constraint({result, i}, {op, i});
|
||||
}
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
std::vector<ir::value*> ops = v->ops();
|
||||
for(ir::value* op: ops)
|
||||
add_constraint({v, i}, {op, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -864,11 +864,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
|
||||
std::map<unsigned, ir::value*> references;
|
||||
create_grids(grids, references, fn);
|
||||
for(ir::value* i: grids){
|
||||
if(auto *instr = dynamic_cast<ir::instruction*>(i))
|
||||
for(unsigned r = 0; r < instr->get_num_results(); r++)
|
||||
init_axes(instr->get_result(r), builder, u_thread_warp_id, u_warp_id);
|
||||
else
|
||||
init_axes(i, builder, u_thread_warp_id, u_warp_id);
|
||||
init_axes(i, builder, u_thread_warp_id, u_warp_id);
|
||||
}
|
||||
// create tile
|
||||
std::set<ir::value*> seen;
|
||||
@@ -876,8 +872,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
if(!i->get_type()->is_tile_ty())
|
||||
continue;
|
||||
for(unsigned r = 0; r < i->get_num_results(); r++)
|
||||
create_tile(i->get_result(r), builder, references, seen, sh_mem_ptr);
|
||||
create_tile(i, builder, references, seen, sh_mem_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
// std::cout << source << std::endl;
|
||||
std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -20,11 +20,6 @@ instruction::instruction(type *ty, unsigned num_ops, unsigned num_results, const
|
||||
auto it = std::find(block->begin(), block->end(), next);
|
||||
block->get_inst_list().insert(it, next);
|
||||
}
|
||||
if(num_results == 1)
|
||||
results_.push_back(this);
|
||||
else
|
||||
for(unsigned i = 0; i < num_results; i++)
|
||||
results_.push_back(new result_reference(this, i));
|
||||
}
|
||||
|
||||
void instruction::erase_from_parent() {
|
||||
|
@@ -48,14 +48,8 @@ void print(module &mod, std::ostream& os) {
|
||||
os << std::endl;
|
||||
for(ir::instruction *inst: block->get_inst_list()){
|
||||
os << " ";
|
||||
unsigned num_results = inst->get_num_results();
|
||||
for(unsigned i = 0; i < num_results; i++){
|
||||
os << get_name(inst->get_result(i), cnt++);
|
||||
if(i < num_results - 1)
|
||||
os << ", ";
|
||||
else
|
||||
os << " = ";
|
||||
}
|
||||
os << get_name(inst, cnt++);
|
||||
os << " = ";
|
||||
ir::type* type = inst->get_type();
|
||||
os << inst->repr() << " " << type->repr();
|
||||
ir::instruction::ops_t ops = inst->ops();
|
||||
|
@@ -5,6 +5,7 @@
|
||||
#include <algorithm>
|
||||
#include "triton/codegen/selection.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/codegen/transform/reorder.h"
|
||||
#include "triton/lang/cpp.h"
|
||||
#include "triton/lang/parser.h"
|
||||
#include "triton/lang/code_gen.h"
|
||||
@@ -198,6 +199,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::analysis::liveness shmem_liveness(&shmem_info);
|
||||
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
|
||||
codegen::analysis::align alignment_info;
|
||||
codegen::transform::reorder reorder(&alignment_info);
|
||||
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
|
||||
codegen::transform::vectorize vectorize(&grids);
|
||||
codegen::transform::dce dce;
|
||||
@@ -208,6 +210,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
alignment_info.run(module);
|
||||
ir::print(module, std::cout);
|
||||
// reorder.run(module);
|
||||
dce.run(module);
|
||||
ir::print(module, std::cout);
|
||||
grids.run(module);
|
||||
reassociate.run(module);
|
||||
dce.run(module);
|
||||
|
@@ -38,7 +38,7 @@ void copy2d(TYPE * X __noalias __readonly __aligned(16),
|
||||
int rm[TM] = ridm * TM + 0 ... TM;
|
||||
int rn[TN] = ridn * TN + 0 ... TN;
|
||||
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
|
||||
TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldy;
|
||||
TYPE* py[TM, TN] = Y + rm[:, newaxis] * ldy + rn[newaxis, :];
|
||||
*py = *px;
|
||||
}
|
||||
)";
|
||||
|
Reference in New Issue
Block a user