[ir][instructions] added permutations option for trans

This commit is contained in:
Philippe Tillet
2019-08-05 21:19:13 -07:00
parent d62e581ab3
commit 26c9849462
10 changed files with 58 additions and 22 deletions

View File

@@ -13,13 +13,14 @@ namespace ir {
class instruction;
class trans_inst;
class builder;
class constant_int;
}
namespace codegen{
class optimize_trans {
private:
ir::value *replace_phi(ir::value* value, ir::builder &builder);
ir::value *replace_phi(ir::value* value, ir::builder &builder, const std::vector<ir::constant_int *> &perm);
public:
optimize_trans() {}

View File

@@ -132,7 +132,7 @@ public:
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
value *create_trans(value *A, const std::string &name = "");
value *create_trans(value *A, const std::vector<constant_int *> &perm = {}, const std::string &name = "");
value *create_sqrt(value *A, const std::string &name = "");
value *create_reduce(value *A, unsigned axis, const std::string &name = "");
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");

View File

@@ -585,13 +585,18 @@ private:
// IR builtin instruction printed as "trans". It takes one operand and carries
// a vector of integer constants (perm_) describing the requested permutation.
class trans_inst: public builtin_inst {
public:
// Computes the result type of the instruction from the operand type `in`.
ir::type* get_res_ty(ir::type* in);
// Builds the permutation that is used when the caller supplies an empty one.
std::vector<constant_int*> get_default_perm(ir::type* ty);
private:
// Constructors are private; instances are obtained through create().
trans_inst(value *arg, const std::string& name, instruction* next);
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
// Textual name of the instruction.
std::string repr_impl() const { return "trans"; }
public:
// Factory functions; `next` is forwarded to the constructor.
// NOTE(review): both overloads are callable with a single `arg` argument
// because every other parameter is defaulted — confirm that only one of them
// is meant to remain, otherwise create(arg) is ambiguous.
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
// Returns the permutation stored on this instruction (by value).
const std::vector<constant_int*> get_perm() const;
private:
// Permutation attached to this instruction; one entry per operand axis.
std::vector<constant_int*> perm_;
};
class sqrt_inst: public builtin_inst {

View File

@@ -68,7 +68,13 @@ void optimize_dot::run(ir::module &mod) {
}
// dot(op(a), b)
if(!trans_b){
ir::value* BB = builder.create_trans(B);
size_t size = B->get_type()->get_tile_shapes().size();
std::vector<ir::constant_int*> perm(size);
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
for(size_t i = 0; i < size; i++)
perm[i] = ir::constant_int::get(int32_ty, i);
std::swap(perm[0], perm[1]);
ir::value* BB = builder.create_trans(B, perm);
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
dot->replace_all_uses_with(NT);
to_delete.push_back(dot);

View File

@@ -7,12 +7,13 @@ namespace codegen{
ir::value* optimize_trans::replace_phi(ir::value* value,
ir::builder& builder){
ir::builder& builder,
const std::vector<ir::constant_int*> &perm){
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
// transpose operands
std::vector<ir::value*> incs;
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
incs.push_back(replace_phi(phi->get_incoming_value(n), builder));
incs.push_back(replace_phi(phi->get_incoming_value(n), builder, perm));
// create phi for transposed values
builder.set_insert_point(phi);
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name());
@@ -26,7 +27,7 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
auto it = std::find(block->begin(), block->end(), i);
it++;
builder.set_insert_point(it);
ir::instruction *trans = (ir::instruction*)builder.create_trans(i);
ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
i->replace_all_uses_with(trans);
trans->set_operand(0, i);
return trans;
@@ -53,7 +54,7 @@ void optimize_trans::run(ir::module &mod) {
// trans(phi) -> phi(trans(), trans()...)
if(dynamic_cast<ir::phi_node*>(op)){
ir::value* new_phi = replace_phi(op, builder);
ir::value* new_phi = replace_phi(op, builder, trans->get_perm());
trans->replace_all_uses_with(new_phi);
}
}

View File

@@ -974,11 +974,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
});
}
// trans
else if(dynamic_cast<ir::trans_inst*>(ins)) {
else if(auto* x = dynamic_cast<ir::trans_inst*>(ins)) {
distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0));
auto perm = x->get_perm();
in->for_each([&](indices_t idx){
indices_t out_idx = idx;
std::rotate(out_idx.begin(), out_idx.begin() + 1, out_idx.end());
indices_t out_idx(idx.size());
for(size_t i = 0; i < idx.size(); i++)
out_idx[i] = idx[perm[i]->get_value()];
ti->set_value(out_idx, in->get_value(idx));
});
}

View File

@@ -98,11 +98,11 @@ void tune::init_c_graph(ir::instruction *v) {
}
// Trans
else if(dynamic_cast<ir::trans_inst*>(v)){
else if(auto *x = dynamic_cast<ir::trans_inst*>(v)){
ir::value *op = v->get_operand(0);
size_t n_shapes = shapes.size();
for(unsigned i = 0; i < n_shapes; i++)
add_constraint({v, (i + 1) % n_shapes}, {op, i});
auto perm = x->get_perm();
for(unsigned i = 0; i < perm.size(); i++)
add_constraint({v, perm[i]->get_value()}, {op, i});
}
// Broadcast
else if(dynamic_cast<ir::broadcast_inst*>(v)){

View File

@@ -75,7 +75,7 @@ void dot::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4";
std::string XBS0 = "TN", XBS1 = "TK/4", XBS2 = "4";
std::string XBS0 = "TK/4", XBS1 = "TN", XBS2 = "4";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
@@ -84,11 +84,13 @@ void dot::triton_c_src(std::ostream &os) const {
std::string useb = BT_ ? "trans(xb)" : "xb";
if(AT_){
std::swap(AS0, AS1);
std::swap(XAS0, XAS1);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
if(BT_){
std::swap(BS0, BS1);
std::swap(XBS0, XBS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
@@ -144,6 +146,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
}
)";
std::cout << res << std::endl;
os << res;
}

View File

@@ -324,8 +324,8 @@ value *builder::create_dot(value *A, value *B, value *C, const std::string &name
return insert(dot_inst::create_nn(A, B, C, name));
}
value *builder::create_trans(value *A, const std::string &name) {
return insert(trans_inst::create(A, name));
// Creates a trans instruction over `A` with permutation `perm` and name
// `name`, then hands it to insert() and returns the inserted value.
value *builder::create_trans(value *A, const std::vector<ir::constant_int*>& perm, const std::string &name) {
  auto *trans = trans_inst::create(A, perm, name);
  return insert(trans);
}
value *builder::create_sqrt(value *A, const std::string &name) {

View File

@@ -572,13 +572,31 @@ ir::type* trans_inst::get_res_ty(ir::type* ty) {
return tile_type::get(ty->get_scalar_ty(), shapes);
}
trans_inst::trans_inst(value *arg, const std::string &name, instruction *next)
// Builds the permutation used when a trans instruction is created without an
// explicit one. With `rank` the number of tile shapes of `ty`, entry i of the
// result is the int32 constant (i + 1) % rank, i.e. the axes rotated by one.
//
// ty: tile type of the operand being transposed.
// Returns one constant per axis; every value lies in [0, rank).
std::vector<constant_int*> trans_inst::get_default_perm(ir::type* ty) {
  auto size = ty->get_tile_shapes().size();
  ir::type* int32_ty = type::get_int32_ty(ty->get_context());
  std::vector<constant_int*> result;
  result.reserve(size);
  // The parentheses matter: `i + 1 % size` parses as `i + (1 % size)`, which
  // produces 1..size and therefore an out-of-range last entry.
  for(size_t i = 0; i < size; i++)
    result.push_back(ir::constant_int::get(int32_ty, (i + 1) % size));
  return result;
}
// Constructs a trans instruction over `arg`.
// An empty `perm` is replaced by get_default_perm() for the operand type.
// The stored permutation must have one entry per tile shape of the operand,
// which is checked with an assert.
trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next)
  : builtin_inst(get_res_ty(arg->get_type()), 1, 1, name, next), perm_(perm) {
  // Fall back to the default permutation when none was given.
  if(perm_.empty())
    perm_ = get_default_perm(arg->get_type());
  auto rank = arg->get_type()->get_tile_shapes().size();
  assert(perm_.size() == rank);
  set_operand(0, arg);
}
instruction* trans_inst::create(value *arg, const std::string &name, instruction *next) {
return new trans_inst(arg, name, next);
// Factory: allocates a new trans_inst over `arg` with the given permutation,
// name and `next` instruction, and returns it as an instruction pointer.
instruction* trans_inst::create(value *arg, const std::vector<constant_int *> &perm, const std::string &name, instruction *next) {
  instruction *created = new trans_inst(arg, perm, name, next);
  return created;
}
const std::vector<constant_int*> trans_inst::get_perm() const {
return perm_;
}
//===----------------------------------------------------------------------===//