This commit is contained in:
Philippe Tillet
2019-09-08 17:35:24 -04:00
parent 0ff81badac
commit 32234c2612
10 changed files with 541 additions and 301 deletions

View File

@@ -2,12 +2,19 @@
#define TDL_INCLUDE_CODEGEN_ALIGNMENT_INFO_PASS_H
#include <map>
#include <vector>
namespace triton {
// Forward declarations: the alignment pass interface in this header only names
// these IR classes through pointers and references.
namespace ir {
class value;
class module;
class phi_node;
class splat_inst;
class reshape_inst;
class broadcast_inst;
class binary_operator;
class getelementptr_inst;
}
namespace codegen{
@@ -22,22 +29,47 @@ class align {
private:
// helpers
bool is_first_axis_unit(ir::value *v);
std::vector<unsigned> get_shapes(ir::value *v);
// populate maps
cst_info populate_is_constant(ir::value *v);
unsigned populate_max_contiguous(ir::value *v);
unsigned populate_starting_multiple(ir::value *v);
// populate is_constant
std::vector<cst_info> populate_is_constant_phi(ir::phi_node* x);
std::vector<cst_info> populate_is_constant_splat(ir::splat_inst* x);
std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
std::vector<cst_info> populate_is_constant_default(ir::value* v);
std::vector<cst_info> populate_is_constant(ir::value *v);
// populate max_contiguous
std::vector<unsigned> populate_max_contiguous_phi(ir::phi_node* x);
std::vector<unsigned> populate_max_contiguous_splat(ir::splat_inst* x);
std::vector<unsigned> populate_max_contiguous_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_max_contiguous_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_max_contiguous_binop(ir::binary_operator* x);
std::vector<unsigned> populate_max_contiguous_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_max_contiguous_default(ir::value* v);
std::vector<unsigned> populate_max_contiguous(ir::value *v);
// populate starting_multiple
std::vector<unsigned> populate_starting_multiple_phi(ir::phi_node* x);
std::vector<unsigned> populate_starting_multiple_splat(ir::splat_inst* x);
std::vector<unsigned> populate_starting_multiple_reshape(ir::reshape_inst* x);
std::vector<unsigned> populate_starting_multiple_broadcast(ir::broadcast_inst* x);
std::vector<unsigned> populate_starting_multiple_binop(ir::binary_operator* x);
std::vector<unsigned> populate_starting_multiple_gep(ir::getelementptr_inst* x);
std::vector<unsigned> populate_starting_multiple_default(ir::value* v);
std::vector<unsigned> populate_starting_multiple(ir::value *v);
public:
void run(ir::module &mod);
unsigned get_starting_multiple(ir::value* v) const;
unsigned get_max_contiguous(ir::value* v) const;
std::vector<unsigned> get_max_contiguous_vec(ir::value* v) const;
void copy(ir::value *dst, ir::value *src);
private:
std::map<ir::value*, cst_info> is_constant_;
std::map<ir::value*, unsigned> max_contiguous_;
std::map<ir::value*, unsigned> starting_multiple_;
std::map<ir::value*, std::vector<cst_info>> is_constant_;
std::map<ir::value*, std::vector<unsigned>> max_contiguous_;
std::map<ir::value*, std::vector<unsigned>> starting_multiple_;
};

View File

@@ -11,6 +11,10 @@
#include "triton/ir/type.h"
#include "triton/ir/metadata.h"
// Expands, inside the body of instruction class `name`, to the definition of
// clone_impl(): it returns a heap-allocated copy of *this built with `name`'s
// copy constructor.
// NOTE(review): the copy is created with a raw `new`; ownership presumably
// passes to whoever calls clone() — confirm.
#define _TRITON_DEFINE_CLONE(name) \
ir::instruction* clone_impl() const { return new name(*this); }
namespace triton{
namespace ir{
@@ -25,10 +29,15 @@ class context;
//===----------------------------------------------------------------------===//
class result_reference;
class instruction: public user{
public:
virtual std::string repr_impl() const = 0;
private:
virtual ir::instruction* clone_impl() const = 0;
protected:
// constructors
instruction(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
@@ -43,19 +52,27 @@ public:
bool has_tile_result_or_op();
// repr
std::string repr() const { return repr_impl(); }
// results
unsigned get_num_results() const { return results_.size(); }
value* get_result(unsigned i) { return results_.at(i); }
// metadata
void set_metadata(ir::metadata::kind_t kind,
unsigned value) { metadatas_[kind] = value;}
// Returns the value recorded for `kind`, or 0 when no metadata of that kind
// was set. Uses find() instead of operator[] so that a read never inserts a
// default entry into metadatas_, which also lets the getter be const.
unsigned get_metadata(ir::metadata::kind_t kind) const {
  auto it = metadatas_.find(kind);
  return it == metadatas_.end() ? 0u : it->second;
}
// cloning
// Builds a copy of this instruction through clone_impl(), renames the copy to
// "testcloned", clears its parent block pointer and returns it.
ir::instruction* clone() {
  ir::instruction* copy = clone_impl();
  // Registering the copy as a user of its operands is currently disabled:
  // for(auto it = op_begin(); it != op_end(); it++){
  //   (*it)->add_use(copy);
  // }
  copy->set_name("testcloned");
  copy->parent_ = nullptr;
  return copy;
}
private:
basic_block *parent_;
std::vector<value*> results_;
std::map<ir::metadata::kind_t, unsigned> metadatas_;
};
// result reference
class result_reference: public value {
public:
@@ -72,7 +89,7 @@ private:
// phi_node classes
//===----------------------------------------------------------------------===//
class phi_node: public instruction{
class phi_node: public instruction {
private:
phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
std::string repr_impl() const { return "phi"; }
@@ -91,6 +108,8 @@ public:
// Factory methods
static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(phi_node)
private:
unsigned num_reserved_;
std::vector<basic_block*> blocks_;
@@ -99,7 +118,7 @@ private:
//===----------------------------------------------------------------------===//
// binary_operator classes
//===----------------------------------------------------------------------===//
class binary_operator: public instruction{
class binary_operator: public instruction {
public:
typedef binary_op_t op_t;
@@ -138,6 +157,8 @@ public:
static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(binary_operator)
public:
binary_op_t op_;
bool has_no_unsigned_wrap_;
@@ -168,20 +189,22 @@ private:
cmp_pred_t pred_;
};
class icmp_inst: public cmp_inst{
class icmp_inst: public cmp_inst {
using cmp_inst::cmp_inst;
public:
static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(icmp_inst)
};
class fcmp_inst: public cmp_inst{
class fcmp_inst: public cmp_inst {
using cmp_inst::cmp_inst;
public:
static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(fcmp_inst)
};
//===----------------------------------------------------------------------===//
@@ -224,7 +247,8 @@ private:
};
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, op) \
class name : public cast_inst{ \
class name : public cast_inst { \
_TRITON_DEFINE_CLONE(name); \
friend class cast_inst; \
name(type *ty, value *v, const std::string &name, instruction *next) \
: cast_inst(ty, v, name, next, op){ } \
@@ -253,7 +277,7 @@ class terminator_inst: public instruction{
};
// return instruction
class return_inst: public terminator_inst{
class return_inst: public terminator_inst {
private:
std::string repr_impl() const { return "ret"; }
return_inst(context &ctx, value *ret_val, instruction *next);
@@ -267,6 +291,8 @@ public:
// factory methods
static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
_TRITON_DEFINE_CLONE(return_inst)
};
// base branch instruction
@@ -294,6 +320,7 @@ public:
basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
value *get_cond() { return get_operand(2); }
_TRITON_DEFINE_CLONE(cond_branch_inst)
};
// unconditional branch
@@ -304,28 +331,15 @@ private:
public:
basic_block *get_dest() { return (basic_block*)get_operand(0); }
_TRITON_DEFINE_CLONE(uncond_branch_inst)
};
// ternary
class ternary_inst: public instruction {
private:
std::string repr_impl() const { return "cond"; }
ternary_inst(value *cond, value *true_value, value *false_value,
const std::string &name, instruction *next);
public:
value *get_cond() { return get_operand(0); }
value *get_true_value() { return get_operand(1); }
value *get_false_value() { return get_operand(2); }
static ternary_inst* create(value *cond, value *true_value, value *false_value,
const std::string &name = "", instruction *next = nullptr);
};
//===----------------------------------------------------------------------===//
// getelementptr_inst classes
//===----------------------------------------------------------------------===//
class getelementptr_inst: public instruction{
class getelementptr_inst: public instruction {
private:
std::string repr_impl() const { return "getelementptr"; }
getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
@@ -345,6 +359,7 @@ public:
// factory methods
static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(getelementptr_inst)
private:
type *source_elt_ty;
@@ -358,12 +373,16 @@ private:
class io_inst: public instruction {
protected:
io_inst(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
// value *get_mask() const;
// value *get_false_value() const;
};
class load_inst: public io_inst{
class load_inst: public io_inst {
protected:
load_inst(value *ptr, unsigned num_extra_ops, const std::string &name, instruction *next);
@@ -372,15 +391,15 @@ private:
static type *get_pointee_type(type *ty);
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
// factory method
static load_inst* create(value *ptr,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(load_inst)
};
class masked_load_inst: public load_inst{
class masked_load_inst: public load_inst {
private:
std::string repr_impl() const { return "masked_load"; }
masked_load_inst(value *ptr, value *mask, value *false_value,
@@ -394,6 +413,7 @@ public:
static masked_load_inst* create(value *ptr, value *mask, value *false_value,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_load_inst)
};
class store_inst: public io_inst{
@@ -406,12 +426,12 @@ private:
public:
// accessors
value *get_pointer_operand() { return get_operand(0); }
value *get_value_operand() { return get_operand(1); }
// factory method
static store_inst* create(value* ptr, value *v,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(store_inst)
};
class masked_store_inst: public store_inst{
@@ -427,6 +447,7 @@ public:
static masked_store_inst* create(value *ptr, value *v, value *mask,
const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(masked_store_inst)
};
//===----------------------------------------------------------------------===//
@@ -450,6 +471,7 @@ private:
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(reshape_inst)
};
// splat
@@ -462,6 +484,7 @@ private:
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(splat_inst)
};
// broadcast
@@ -474,6 +497,7 @@ private:
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(broadcast_inst)
};
@@ -486,6 +510,7 @@ private:
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(downcast_inst)
};
//===----------------------------------------------------------------------===//
@@ -505,6 +530,7 @@ private:
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_program_id_inst)
private:
unsigned axis_;
@@ -518,6 +544,7 @@ private:
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_num_program_inst)
private:
unsigned axis_;
@@ -527,6 +554,7 @@ class atomic_cas_inst: public builtin_inst {
private:
atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
std::string repr_impl() const { return "atomic_cas"; }
_TRITON_DEFINE_CLONE(atomic_cas_inst)
public:
static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
@@ -536,6 +564,7 @@ class atomic_exch_inst: public builtin_inst {
private:
atomic_exch_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_exch"; }
_TRITON_DEFINE_CLONE(atomic_exch_inst)
public:
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
@@ -545,6 +574,7 @@ class atomic_add_inst: public builtin_inst {
private:
atomic_add_inst(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
std::string repr_impl() const { return "atomic_add"; }
_TRITON_DEFINE_CLONE(atomic_add_inst)
public:
static instruction* create(value *ptr, value *val, const std::string &name = "", instruction *next = nullptr);
@@ -566,6 +596,7 @@ public:
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
bool is_a_trans() { return AT_ == Trans; }
bool is_b_trans() { return BT_ == Trans; }
_TRITON_DEFINE_CLONE(dot_inst)
private:
TransT AT_;
@@ -586,17 +617,12 @@ public:
private:
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
std::string repr_impl() const {
std::string res = "trans<";
//for(ir::constant_int *x: perm_)
// res += x->repr() + ",";
res[res.size()-1] = '>';
return res;
}
std::string repr_impl() const { return "trans"; }
public:
static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
const std::vector<constant_int*> get_perm() const;
_TRITON_DEFINE_CLONE(trans_inst)
private:
std::vector<constant_int*> perm_;
@@ -608,6 +634,7 @@ private:
std::string repr_impl() const { return "sqrt"; }
public:
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(sqrt_inst)
};
class reduce_inst: public builtin_inst {
@@ -617,6 +644,7 @@ private:
private:
reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
std::string repr_impl() const { return "reduce"; }
_TRITON_DEFINE_CLONE(reduce_inst)
public:
static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
@@ -630,6 +658,7 @@ class select_inst: public builtin_inst {
private:
select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
std::string repr_impl() const { return "select"; }
_TRITON_DEFINE_CLONE(select_inst)
public:
static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
@@ -647,12 +676,14 @@ private:
public:
static copy_to_shared_inst* create(value *arg, const std::string &name = "",
instruction *next = nullptr);
_TRITON_DEFINE_CLONE(copy_to_shared_inst)
};
class barrier_inst: public instruction{
private:
barrier_inst(context &ctx, const std::string &name, instruction *next);
std::string repr_impl() const { return "barrier"; }
_TRITON_DEFINE_CLONE(barrier_inst)
public:
static barrier_inst* create(context &ctx, const std::string &name = "",
@@ -663,6 +694,7 @@ class vectorize_inst: public unary_inst{
private:
using unary_inst::unary_inst;
std::string repr_impl() const { return "vectorize"; }
_TRITON_DEFINE_CLONE(vectorize_inst)
public:
static vectorize_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
@@ -675,6 +707,7 @@ class nv_dynamic_program_idx_inst: public instruction {
private:
nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next);
std::string repr_impl() const { return "nv_dynamic_program_idx"; }
_TRITON_DEFINE_CLONE(nv_dynamic_program_idx_inst)
public:
static nv_dynamic_program_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);

View File

@@ -5,6 +5,8 @@
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
#include <iostream>
#include <numeric>
#include <algorithm>
namespace triton {
namespace codegen{
@@ -36,258 +38,448 @@ bool align::is_first_axis_unit(ir::value *x){
return true;
}
align::cst_info align::populate_is_constant(ir::value *v) {
/*
* is constant
*/
// Returns the tile shapes of `v` when its type is a tile type; any other
// value is reported as the single-axis shape {1}.
std::vector<unsigned> align::get_shapes(ir::value *v) {
  ir::type *type = v->get_type();
  if(!type->is_tile_ty())
    return {1};
  return type->get_tile_shapes();
}
// Per-axis constancy info for a phi node.
// Every axis starts at {1, 0}; the result is then replaced by the cached entry
// of any incoming value that already has one (a later incoming value
// overwrites an earlier one), and that result is cached for `x` and returned.
std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto it = is_constant_.find(inc);
if(it != is_constant_.end())
result = it->second;
}
// NOTE(review): this return makes everything below unreachable, so incoming
// values are never recursed into from here. Presumably deliberate (e.g. to
// avoid cycling through loop-carried phis) — confirm before removing either
// the return or the dead loop.
return add_to_cache(x, result, is_constant_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
auto cst = populate_is_constant(inc);
for(size_t d = 0; d < cst.size(); d++)
result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst);
}
return add_to_cache(x, result, is_constant_);
}
// Per-axis constancy info for a splat: each axis of the result gets
// {axis extent, value recorded for the operand's first axis}.
// The result is stored in is_constant_ and returned.
std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
  auto out_shapes = get_shapes(x);
  ir::value* operand = x->get_operand(0);
  auto operand_cst = populate_is_constant(operand);
  std::vector<cst_info> axes;
  for(size_t d = 0; d < out_shapes.size(); d++){
    cst_info ax{out_shapes[d], operand_cst[0].value};
    axes.push_back(ax);
  }
  return add_to_cache(x, axes, is_constant_);
}
// Per-axis constancy info for a reshape, derived from its operand.
//
// Output axes are visited left to right while `current` tracks the next
// operand axis that has not been matched yet:
//  - a unit output axis gets {1, value recorded for operand axis `current`};
//  - an output axis whose extent equals that of operand axis `current`
//    gets {extent, value of that operand axis} and consumes the operand axis;
//  - any other axis marks the reshape as skewed: it, and every later non-unit
//    axis, gets {extent, 0}.
//
// `current` is bounds-checked before each use. Without the check, a result
// with more axes than the operand (for instance [N] -> [N, 1]) indexes
// op_cst / op_shapes one past the end. Once the operand axes are exhausted,
// a unit axis gets value 0 and a non-unit axis is treated as skewed.
//
// The result is stored in is_constant_ and returned.
std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
  auto x_shapes = get_shapes(x);
  std::vector<cst_info> result;
  ir::value *op = x->get_operand(0);
  auto op_shapes = op->get_type()->get_tile_shapes();
  auto op_cst = populate_is_constant(op);
  size_t current = 0;
  bool is_skewed = false;
  for(size_t d = 0; d < x_shapes.size(); d++){
    // true while operand axis `current` exists in both operand vectors
    const bool in_range = current < op_shapes.size() && current < op_cst.size();
    cst_info ax;
    if(x_shapes[d] == 1)
      ax = {1, in_range ? op_cst[current].value : 0u};
    else if(!is_skewed && in_range && x_shapes[d] == op_shapes[current])
      ax = {x_shapes[d], op_cst[current++].value};
    else {
      is_skewed = true;
      ax = {x_shapes[d], 0};
    }
    result.push_back(ax);
  }
  return add_to_cache(x, result, is_constant_);
}
// Per-axis constancy info for a broadcast.
// An axis whose operand extent is 1 gets {result extent, operand value};
// every other axis copies the operand's info unchanged.
// The result is stored in is_constant_ and returned.
std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
  auto out_shapes = get_shapes(x);
  ir::value *operand = x->get_operand(0);
  auto in_shapes = operand->get_type()->get_tile_shapes();
  auto in_cst = populate_is_constant(operand);
  std::vector<cst_info> axes;
  for(size_t d = 0; d < out_shapes.size(); d++){
    const bool stretched = (in_shapes[d] == 1);
    axes.push_back(stretched ? cst_info{out_shapes[d], in_cst[d].value}
                             : in_cst[d]);
  }
  return add_to_cache(x, axes, is_constant_);
}
// Per-axis constancy info for a binary operator.
// For each axis:
//  - if the lhs has num_cst == 0, the rhs has a non-zero recorded value and
//    the operator is an integer division, num_cst is
//    gcd(max contiguous of lhs, rhs value);
//  - otherwise num_cst is the smaller of the two operands' num_cst.
// The recorded value is always 0. The result is stored in is_constant_.
std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
  auto out_shapes = get_shapes(x);
  ir::value* lhs_value = x->get_operand(0);
  ir::value* rhs_value = x->get_operand(1);
  auto lhs_cst = populate_is_constant(lhs_value);
  auto rhs_cst = populate_is_constant(rhs_value);
  auto lhs_contiguous = populate_max_contiguous(lhs_value);
  std::vector<cst_info> axes;
  for(size_t d = 0; d < out_shapes.size(); d++) {
    unsigned num_constants;
    if(lhs_cst[d].num_cst==0 && rhs_cst[d].value && x->is_int_div()){
      // todo might not be entirely true
      num_constants = gcd(lhs_contiguous[d], rhs_cst[d].value);
    }
    else
      num_constants = std::min(lhs_cst[d].num_cst, rhs_cst[d].num_cst);
    axes.push_back(cst_info{num_constants, 0});
  }
  return add_to_cache(x, axes, is_constant_);
}
// Per-axis constancy info for a getelementptr: on each axis num_cst is the
// smaller of the num_cst of operand 0 and operand 1, and the recorded value
// is 0. The result is stored in is_constant_ and returned.
std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
  auto out_shapes = get_shapes(x);
  ir::value* ptr_value = x->get_operand(0);
  ir::value* idx_value = x->get_operand(1);
  auto ptr_cst = populate_is_constant(ptr_value);
  auto idx_cst = populate_is_constant(idx_value);
  std::vector<cst_info> axes;
  for(size_t d = 0; d < out_shapes.size(); d++){
    cst_info ax{std::min(ptr_cst[d].num_cst, idx_cst[d].num_cst), 0};
    axes.push_back(ax);
  }
  return add_to_cache(x, axes, is_constant_);
}
// Fallback constancy info for values with no dedicated handler: every axis
// of `v` gets {1, 0}. The result is stored in is_constant_ and returned.
std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
  auto v_shapes = get_shapes(v);
  const cst_info fallback{1, 0};
  std::vector<cst_info> axes(v_shapes.size(), fallback);
  return add_to_cache(v, axes, is_constant_);
}
std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
if(is_constant_.find(v) != is_constant_.end())
return is_constant_.at(v);
// helper for the cache
auto cache = [this,v](cst_info value){
return add_to_cache(v, value, is_constant_); }
;
// populate
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
ir::value *op = x->get_operand(0);
auto op_cst = populate_is_constant(op);
if(is_first_axis_unit(op)){
unsigned num_cst = x->get_type()->get_tile_shapes()[0];
return cache({num_cst, op_cst.value});
}
}
if(auto *x = dynamic_cast<ir::constant_int*>(v))
return cache({true, (unsigned)x->get_value()});
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
cst_info lhs = populate_is_constant(lhs_op);
cst_info rhs = populate_is_constant(rhs_op);
if(lhs.num_cst==0 && rhs.value && x->is_int_div()){
unsigned max_contiguous = populate_max_contiguous(lhs_op);
// todo might not be entirely true
unsigned num_constants = gcd(max_contiguous, rhs.value);
return cache({num_constants, 0});
}
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
}
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
ir::value* lhs_op = x->get_operand(0);
ir::value* rhs_op = x->get_operand(1);
cst_info lhs = populate_is_constant(lhs_op);
cst_info rhs = populate_is_constant(rhs_op);
return cache({std::min(lhs.num_cst, rhs.num_cst), 0});
}
// if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
// cst_info value_true = populate_is_constant(x->get_value_true());
// cst_info value_false = populate_is_constant(x->get_value_false());
// return cache({std::min(value_true.num_cst, value_false.num_cst), 0});
// }
if(v->get_type()->is_tile_ty())
return cache({0, 0});
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
// put a conservative initial value in phi node to avoid infinite recursion
unsigned result = 1;
return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_);
if(auto *x = dynamic_cast<ir::phi_node*>(v))
return populate_is_constant_phi(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v))
return populate_is_constant_splat(x);
if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
return populate_is_constant_reshape(x);
if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
return populate_is_constant_broadcast(x);
if(auto *x = dynamic_cast<ir::binary_operator*>(v))
return populate_is_constant_binop(x);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_is_constant_gep(x);
return populate_is_constant_default(v);
}
/*
* max contiguous
*/
std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result(shapes.size(), 1);
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
if(is_constant_.find(inc) != is_constant_.end())
result = is_constant_.at(inc).num_cst;
auto it = max_contiguous_.find(inc);
if(it != max_contiguous_.end())
result = it->second;
}
cache({result, 0});
add_to_cache(x, result, max_contiguous_);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
result = std::min(result, populate_is_constant(inc).num_cst);
auto contiguous = populate_max_contiguous(inc);
for(size_t d = 0; d < result.size(); d++)
result[d] = std::min(result[d], contiguous[d]);
}
return cache({result, 0});
}
// scalars are always constant in the contiguous dimension
// but value is not known at compile-time
return cache({1, 0});
return add_to_cache(x, result, max_contiguous_);
}
unsigned align::populate_max_contiguous(ir::value *v){
if(max_contiguous_.find(v) != max_contiguous_.end())
return max_contiguous_.at(v);
// helper for the cache
auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
// populate
if(!v->get_type()->is_tile_ty())
return cache(1);
auto shapes = v->get_type()->get_tile_shapes();
if(dynamic_cast<ir::constant_range*>(v)){
return cache(shapes[0]);
}
if(auto *x = dynamic_cast<ir::retile_inst*>(v)){
// Max-contiguous info for a splat: 1 on every axis of the result.
// The result is stored in max_contiguous_ and returned.
std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
  auto out_shapes = get_shapes(x);
  std::vector<unsigned> contiguous(out_shapes.size(), 1);
  return add_to_cache(x, contiguous, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
auto shapes = get_shapes(x);
std::vector<unsigned> result;
ir::value *op = x->get_operand(0);
if(op->get_type()->is_tile_ty()){
auto op_shapes = op->get_type()->get_tile_shapes();
if(op_shapes[0] == shapes[0])
return cache(populate_max_contiguous(op));
auto op_mc = populate_max_contiguous(op);
unsigned current = 0;
bool is_skewed = false;
for(size_t d = 0; d < shapes.size(); d ++){
if(shapes[d] == 1)
result.push_back(1);
else if(!is_skewed
&& shapes[d] == op_shapes[current])
result.push_back(op_mc[current++]);
else {
is_skewed = true;
result.push_back(1);
}
return cache(1);
}
if(auto *x = dynamic_cast<ir::binary_operator*>(v)){
return add_to_cache(x, result, max_contiguous_);
}
// Max-contiguous info for a broadcast: an axis whose operand extent is 1
// gets 1, every other axis inherits the operand's value for that axis.
// The result is stored in max_contiguous_ and returned.
std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
  auto out_shapes = get_shapes(x);
  ir::value *operand = x->get_operand(0);
  auto in_shapes = operand->get_type()->get_tile_shapes();
  auto in_contiguous = populate_max_contiguous(operand);
  std::vector<unsigned> contiguous;
  for(size_t d = 0; d < out_shapes.size(); d++){
    const bool stretched = (in_shapes[d] == 1);
    contiguous.push_back(stretched ? 1u : in_contiguous[d]);
  }
  return add_to_cache(x, contiguous, max_contiguous_);
}
std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
cst_info lhs_cst_info = populate_is_constant(lhs);
cst_info rhs_cst_info = populate_is_constant(rhs);
if(x->is_int_rem() && rhs_cst_info.value > 0)
return cache(std::min(lhs_max_contiguous, rhs_cst_info.value));
if(x->is_int_mult()){
if(rhs_cst_info.value == 1)
return cache(lhs_max_contiguous);
if(lhs_cst_info.value == 1)
return cache(rhs_max_contiguous);
}
if(x->is_int_add_sub()){
if(lhs_cst_info.num_cst)
return cache(gcd(rhs_max_contiguous, lhs_cst_info.num_cst));
if(rhs_cst_info.num_cst)
return cache(gcd(lhs_max_contiguous, rhs_cst_info.num_cst));
}
}
// if(auto *x = dynamic_cast<ir::psi_inst*>(v)){
// int value_true = populate_max_contiguous(x->get_value_true());
// int value_false = populate_max_contiguous(x->get_value_false());
// return cache(std::min(value_true, value_false));
// }
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
if(lhs_cst_info.num_cst)
return cache(rhs_max_contiguous);
if(rhs_cst_info.num_cst)
return cache(lhs_max_contiguous);
std::vector<unsigned> result;
for(size_t d = 0; d < shapes.size(); d++){
unsigned value = 1;
if(x->is_int_rem() && rhs_cst_info[d].value > 0)
value = std::min(lhs_max_contiguous[d], rhs_cst_info[d].value);
if(x->is_int_mult()){
unsigned lvalue = 1, rvalue = 1;
if(rhs_cst_info[d].value == 1)
lvalue = lhs_max_contiguous[d];
if(lhs_cst_info[d].value == 1)
rvalue = rhs_max_contiguous[d];
value = std::max(lvalue, rvalue);
}
if(auto *x = dynamic_cast<ir::phi_node*>(v)){
// put a conservative initial value in phi node to avoid infinite recursion
unsigned result = 1;
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
if(max_contiguous_.find(inc) != max_contiguous_.end())
result = max_contiguous_.at(inc);
if(x->is_int_add_sub()){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst)
lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
if(rhs_cst_info[d].num_cst)
rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
value = std::max(lvalue, rvalue);
}
cache(result);
// recurse
for(unsigned n = 0; n < x->get_num_incoming(); n++){
ir::value* inc = x->get_incoming_value(n);
result = std::min(result, populate_max_contiguous(inc));
result.push_back(value);
}
return cache(result);
}
return cache(1);
return add_to_cache(x, result, max_contiguous_);
}
unsigned align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v);
auto cache = [this,v](unsigned value){
return add_to_cache(v, value, starting_multiple_);
};
// has metadata
if(auto *x = dynamic_cast<ir::instruction*>(v)){
unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
if(multiple_of > 0)
return cache(multiple_of);
std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
auto shapes = get_shapes(x);
ir::value* lhs = x->get_operand(0);
ir::value* rhs = x->get_operand(1);
auto lhs_max_contiguous = populate_max_contiguous(lhs);
auto rhs_max_contiguous = populate_max_contiguous(rhs);
auto lhs_cst_info = populate_is_constant(lhs);
auto rhs_cst_info = populate_is_constant(rhs);
std::vector<unsigned> result(shapes.size(), 1);
for(size_t d = 0; d < shapes.size(); d++){
unsigned lvalue = 1, rvalue = 1;
if(lhs_cst_info[d].num_cst)
lvalue = rhs_max_contiguous[d];
if(rhs_cst_info[d].num_cst)
rvalue = lhs_max_contiguous[d];
result[d] = std::max(lvalue, rvalue);
}
// arguments
if(auto *x = dynamic_cast<ir::argument*>(v)){
std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
for(auto attr: attributes){
if(attr.get_kind() == ir::multiple_of)
return cache(attr.get_value());
if(attr.get_kind() == ir::aligned){
ir::type* ty = x->get_type()->get_pointer_element_ty();
int nbits = ty->get_primitive_size_in_bits();
int nbytes = nbits / 8;
return cache(attr.get_value() / nbytes);
return add_to_cache(x, result, max_contiguous_);
}
// Fallback max-contiguous info for values with no dedicated handler:
//  - a non-tile value gets {1};
//  - a constant_range gets the single entry {first tile extent};
//  - any other tile gets 1 on every axis.
// The result is stored in max_contiguous_ and returned.
std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
  std::vector<unsigned> contiguous{1};
  if(v->get_type()->is_tile_ty()){
    auto tile_shapes = v->get_type()->get_tile_shapes();
    if(dynamic_cast<ir::constant_range*>(v))
      contiguous = {tile_shapes[0]};
    else
      contiguous.assign(tile_shapes.size(), 1);
  }
  return add_to_cache(v, contiguous, max_contiguous_);
}
// Returns the max-contiguous info of `v`: the cached entry when one exists,
// otherwise the result of the handler matching the dynamic type of `v`
// (splat, reshape, broadcast, binary operator, getelementptr, phi), with
// populate_max_contiguous_default as the fallback.
std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
  auto cached = max_contiguous_.find(v);
  if(cached != max_contiguous_.end())
    return cached->second;
  if(auto *splat = dynamic_cast<ir::splat_inst*>(v))
    return populate_max_contiguous_splat(splat);
  if(auto *reshape = dynamic_cast<ir::reshape_inst*>(v))
    return populate_max_contiguous_reshape(reshape);
  if(auto *broadcast = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_max_contiguous_broadcast(broadcast);
  if(auto *binop = dynamic_cast<ir::binary_operator*>(v))
    return populate_max_contiguous_binop(binop);
  if(auto *gep = dynamic_cast<ir::getelementptr_inst*>(v))
    return populate_max_contiguous_gep(gep);
  if(auto *phi = dynamic_cast<ir::phi_node*>(v))
    return populate_max_contiguous_phi(phi);
  return populate_max_contiguous_default(v);
}
/*
* starting multiple
*/
// Starting-multiple info for a splat: every axis of the result gets the
// first entry of the operand's starting-multiple info.
// The result is stored in starting_multiple_ and returned.
std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
  auto out_shapes = get_shapes(x);
  ir::value *operand = x->get_operand(0);
  auto operand_multiple = populate_starting_multiple(operand);
  const unsigned first_multiple = operand_multiple[0];
  std::vector<unsigned> multiples(out_shapes.size(), first_multiple);
  return add_to_cache(x, multiples, starting_multiple_);
}
// Starting multiple of a reshape.
// Walks the result axes in order and matches them against the operand axes:
//  - a result axis of extent 1 gets multiple 1;
//  - while no mismatch has been seen, a result axis whose extent equals the
//    next unmatched operand axis inherits that operand axis' multiple;
//  - the first mismatch marks the reshape as skewed, and every remaining
//    non-unit axis conservatively gets multiple 1.
// The result is stored in starting_multiple_ through add_to_cache.
std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
  auto op = populate_starting_multiple(x->get_operand(0));
  auto op_shapes = get_shapes(x->get_operand(0));
  auto shapes = get_shapes(x);
  std::vector<unsigned> result(shapes.size(), 1);
  unsigned current = 0;
  bool is_skewed = false;
  for(size_t d = 0; d < shapes.size(); d++){
    if(shapes[d] == 1){
      result[d] = 1;
      continue;
    }
    // never index past the operand's axes / multiples
    bool in_range = current < op_shapes.size() && current < op.size();
    if(!is_skewed && in_range && shapes[d] == op_shapes[current])
      result[d] = op[current++];
    else {
      is_skewed = true;
      result[d] = 1;
    }
  }
  return add_to_cache(x, result, starting_multiple_);
}
// Starting multiple of a broadcast: identical to the starting multiple of its
// operand. The result is stored in starting_multiple_ through add_to_cache.
std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
  ir::value* source = x->get_operand(0);
  std::vector<unsigned> inherited = populate_starting_multiple(source);
  return add_to_cache(x, inherited, starting_multiple_);
}
// Starting multiple of a binary operator, computed independently per axis
// from the starting multiples of the two operands:
//  - integer multiplication : lhs * rhs
//  - integer add / sub      : gcd(lhs, rhs)
//  - integer division       : max(lhs / rhs, 1)
//  - integer remainder      : gcd(lhs, rhs), only when rhs > 1
//  - shift left             : lhs << rhs
//  - shift right            : max(lhs >> rhs, 1)
// Axes for which no rule applies keep the conservative value 1.
// The result is stored in starting_multiple_ through add_to_cache.
std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
  auto lhs = populate_starting_multiple(x->get_operand(0));
  auto rhs = populate_starting_multiple(x->get_operand(1));
  std::vector<unsigned> result(lhs.size(), 1);
  for(size_t d = 0; d < lhs.size(); d++){
    if(x->is_int_mult())
      result[d] = lhs[d] * rhs[d];
    if(x->is_int_add_sub())
      result[d] = gcd(lhs[d], rhs[d]);
    // a zero multiple on the right-hand side would make the division trap;
    // keep the conservative value 1 in that case
    if(x->is_int_div() && rhs[d] != 0)
      result[d] = std::max<unsigned>(lhs[d] / rhs[d], 1);
    if(x->is_int_rem() && rhs[d] > 1)
      result[d] = gcd(lhs[d], rhs[d]);
    if(x->is_shl())
      result[d] = lhs[d] << rhs[d];
    if(x->is_shr())
      result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1);
  }
  return add_to_cache(x, result, starting_multiple_);
}
// Starting multiple of a getelementptr: per-axis gcd of the starting multiples
// of operand 0 and operand 1. The number of axes is taken from operand 0.
// The result is stored in starting_multiple_ through add_to_cache.
std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
  auto base = populate_starting_multiple(x->get_operand(0));
  auto offset = populate_starting_multiple(x->get_operand(1));
  const size_t rank = base.size();
  std::vector<unsigned> combined(rank, 1);
  size_t axis = 0;
  while(axis < rank){
    combined[axis] = gcd(base[axis], offset[axis]);
    ++axis;
  }
  return add_to_cache(x, combined, starting_multiple_);
}
// Starting multiple of a phi node.
// A provisional value is cached for the phi *before* visiting its incoming
// values, so that a cycle through the phi finds a cache entry instead of
// recursing again. The provisional value is 1 along every axis, overridden by
// the entry of any incoming value that is already in starting_multiple_ (the
// last such incoming value wins). The final value is the per-axis gcd of the
// provisional value and the starting multiple of every incoming value.
std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
  auto shape = get_shapes(x);
  std::vector<unsigned> result(shape.size(), 1);
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::value* inc = x->get_incoming_value(n);
    auto known = starting_multiple_.find(inc);
    if(known != starting_multiple_.end())
      result = known->second;
  }
  // provisional entry: breaks the recursion on loops through this phi
  add_to_cache(x, result, starting_multiple_);
  // recurse
  for(unsigned n = 0; n < x->get_num_incoming(); n++){
    ir::value* inc = x->get_incoming_value(n);
    auto sm = populate_starting_multiple(inc);
    for(size_t d = 0; d < result.size(); d++)
      result[d] = gcd(result[d], sm[d]);
  }
  return add_to_cache(x, result, starting_multiple_);
}
// Fallback for values that have no dedicated starting-multiple rule.
//  - tile values are recorded with their tile shapes;
//  - instructions carrying a positive `multiple_of` metadata are recorded
//    with that single value;
//  - arguments are recorded from their attributes: `multiple_of` gives the
//    attribute value, `aligned` gives the attribute value divided by the size
//    in bytes of the pointed-to element type (the first matching attribute wins);
//  - anything else is recorded as {1}.
// The chosen vector is stored in starting_multiple_ through add_to_cache.
std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
  ir::type* ty = v->get_type();
  if(ty->is_tile_ty()) {
    return add_to_cache(v, ty->get_tile_shapes(), starting_multiple_);
  }
  if(auto *x = dynamic_cast<ir::instruction*>(v)){
    unsigned multiple_of = x->get_metadata(ir::metadata::multiple_of);
    if(multiple_of > 0)
      return add_to_cache(x, {multiple_of}, starting_multiple_);
  }
  if(auto *x = dynamic_cast<ir::argument*>(v)){
    std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
    for(auto attr: attributes){
      if(attr.get_kind() == ir::multiple_of){
        return add_to_cache(x, {attr.get_value()}, starting_multiple_);
      }
      if(attr.get_kind() == ir::aligned){
        // alignment is divided by the element size in bytes
        ir::type* elt_ty = x->get_type()->get_pointer_element_ty();
        int nbits = elt_ty->get_primitive_size_in_bits();
        int nbytes = nbits / 8;
        return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
      }
    }
  }
  return add_to_cache(v, {1}, starting_multiple_);
}
// Entry point of the starting-multiple analysis for a single value.
// Returns the entry already stored in starting_multiple_ when there is one;
// otherwise handles constants and program indices inline and forwards every
// other kind of value to its dedicated handler, falling back to
// populate_starting_multiple_default.
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
  // already computed: reuse the stored entry
  auto cached = starting_multiple_.find(v);
  if(cached != starting_multiple_.end())
    return cached->second;
  // dispatch on the concrete kind of value (order of the checks is preserved)
  if(auto *binop = dynamic_cast<ir::binary_operator*>(v))
    return populate_starting_multiple_binop(binop);
  if(auto *cst = dynamic_cast<ir::constant_int*>(v))
    return add_to_cache(cst, {static_cast<unsigned>(cst->get_value())}, starting_multiple_);
  if(auto *range = dynamic_cast<ir::constant_range*>(v))
    return add_to_cache(range, {static_cast<unsigned>(range->get_first()->get_value())}, starting_multiple_);
  if(auto *dyn_idx = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(v))
    return add_to_cache(dyn_idx, {128}, starting_multiple_);
  if(auto *static_idx = dynamic_cast<ir::nv_static_program_idx*>(v))
    return add_to_cache(static_idx, {static_cast<unsigned>(static_idx->get_range()->get_first()->get_value())}, starting_multiple_);
  if(auto *gep = dynamic_cast<ir::getelementptr_inst*>(v))
    return populate_starting_multiple_gep(gep);
  if(auto *splat = dynamic_cast<ir::splat_inst*>(v))
    return populate_starting_multiple_splat(splat);
  if(auto *reshape = dynamic_cast<ir::reshape_inst*>(v))
    return populate_starting_multiple_reshape(reshape);
  if(auto *broadcast = dynamic_cast<ir::broadcast_inst*>(v))
    return populate_starting_multiple_broadcast(broadcast);
  if(auto *phi = dynamic_cast<ir::phi_node*>(v))
    return populate_starting_multiple_phi(phi);
  return populate_starting_multiple_default(v);
}
// Returns entry 0 of the starting multiple recorded for `v`.
// `v` must already have an entry in starting_multiple_: the lookup uses
// map::at, which throws std::out_of_range for an unknown key.
unsigned align::get_starting_multiple(ir::value* v) const {
  return starting_multiple_.at(v)[0];
}
// Returns entry 0 of the max-contiguous vector recorded for `v`.
// The lookup uses map::at, so `v` must already be present in max_contiguous_.
unsigned align::get_max_contiguous(ir::value* v) const {
  const auto& per_axis = max_contiguous_.at(v);
  return per_axis[0];
}
// Returns a copy of the whole max-contiguous vector recorded for `v`.
// The lookup uses map::at, so `v` must already be present in max_contiguous_.
std::vector<unsigned> align::get_max_contiguous_vec(ir::value* v) const {
  const auto& per_axis = max_contiguous_.at(v);
  return per_axis;
}
@@ -297,7 +489,7 @@ void align::copy(ir::value *dst, ir::value *src) {
is_constant_[dst] = is_constant_[src];
}
///TODO: This doesn't seem to work in DOT-NN, DOT-TT, DOT-TN
void align::run(ir::module &mod) {
// populate constant
for(ir::function *fn: mod.get_function_list())
@@ -316,13 +508,9 @@ void align::run(ir::module &mod) {
// populate maximum contiguous
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
for(ir::instruction *i: block->get_inst_list()){
populate_max_contiguous(i);
// for(ir::function *fn: mod.get_function_list())
// for(ir::basic_block *block: fn->blocks())
// for(ir::instruction *i: block->get_inst_list())
// std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << std::endl;
}
}

View File

@@ -76,16 +76,16 @@ void grids::init_c_graph(ir::instruction *v) {
// Reshape
if(dynamic_cast<ir::reshape_inst*>(v)) {
ir::value *op = v->get_operand(0);
auto op_shapes = op->get_type()->get_tile_shapes();
unsigned current = 0;
bool is_skewed = false;
for(unsigned i = 0; i < shapes.size(); i ++){
bool is_one = shapes[i] == 1;
bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
if(is_one){
if(shapes[i] == 1){
static_params_.insert({{v, i}, 1});
add_constraint({v, i}, {v, i});
}
else if(!is_skewed && is_same)
else if(!is_skewed &&
shapes[i] == op_shapes[current])
add_constraint({v, i}, {op, current++});
else{
is_skewed = true;
@@ -130,13 +130,10 @@ void grids::init_c_graph(ir::instruction *v) {
}
// Element-wise
else if(dynamic_cast<ir::user*>(v)) {
for(unsigned k = 0; k < v->get_num_results(); k++){
ir::value *result = v->get_result(k);
for(unsigned i = 0; i < shapes.size(); i ++){
std::vector<ir::value*> ops = v->ops();
for(ir::value* op: ops)
add_constraint({result, i}, {op, i});
}
add_constraint({v, i}, {op, i});
}
}
}

View File

@@ -864,10 +864,6 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
std::map<unsigned, ir::value*> references;
create_grids(grids, references, fn);
for(ir::value* i: grids){
if(auto *instr = dynamic_cast<ir::instruction*>(i))
for(unsigned r = 0; r < instr->get_num_results(); r++)
init_axes(instr->get_result(r), builder, u_thread_warp_id, u_warp_id);
else
init_axes(i, builder, u_thread_warp_id, u_warp_id);
}
// create tile
@@ -876,8 +872,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
for(ir::instruction *i: block->get_inst_list()){
if(!i->get_type()->is_tile_ty())
continue;
for(unsigned r = 0; r < i->get_num_results(); r++)
create_tile(i->get_result(r), builder, references, seen, sh_mem_ptr);
create_tile(i, builder, references, seen, sh_mem_ptr);
}
}

View File

@@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -20,11 +20,6 @@ instruction::instruction(type *ty, unsigned num_ops, unsigned num_results, const
auto it = std::find(block->begin(), block->end(), next);
block->get_inst_list().insert(it, next);
}
if(num_results == 1)
results_.push_back(this);
else
for(unsigned i = 0; i < num_results; i++)
results_.push_back(new result_reference(this, i));
}
void instruction::erase_from_parent() {

View File

@@ -48,14 +48,8 @@ void print(module &mod, std::ostream& os) {
os << std::endl;
for(ir::instruction *inst: block->get_inst_list()){
os << " ";
unsigned num_results = inst->get_num_results();
for(unsigned i = 0; i < num_results; i++){
os << get_name(inst->get_result(i), cnt++);
if(i < num_results - 1)
os << ", ";
else
os << get_name(inst, cnt++);
os << " = ";
}
ir::type* type = inst->get_type();
os << inst->repr() << " " << type->repr();
ir::instruction::ops_t ops = inst->ops();

View File

@@ -5,6 +5,7 @@
#include <algorithm>
#include "triton/codegen/selection.h"
#include "triton/runtime/function.h"
#include "triton/codegen/transform/reorder.h"
#include "triton/lang/cpp.h"
#include "triton/lang/parser.h"
#include "triton/lang/code_gen.h"
@@ -198,6 +199,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::analysis::liveness shmem_liveness(&shmem_info);
codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids);
codegen::analysis::align alignment_info;
codegen::transform::reorder reorder(&alignment_info);
codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info);
codegen::transform::vectorize vectorize(&grids);
codegen::transform::dce dce;
@@ -208,6 +210,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
peephole.run(module);
dce.run(module);
alignment_info.run(module);
ir::print(module, std::cout);
// reorder.run(module);
dce.run(module);
ir::print(module, std::cout);
grids.run(module);
reassociate.run(module);
dce.run(module);

View File

@@ -38,7 +38,7 @@ void copy2d(TYPE * X __noalias __readonly __aligned(16),
int rm[TM] = ridm * TM + 0 ... TM;
int rn[TN] = ridn * TN + 0 ... TN;
TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldy;
TYPE* py[TM, TN] = Y + rm[:, newaxis] * ldy + rn[newaxis, :];
*py = *px;
}
)";