[ir][instructions] added permutations option for trans
This commit is contained in:
@@ -13,13 +13,14 @@ namespace ir {
|
||||
class instruction;
|
||||
class trans_inst;
|
||||
class builder;
|
||||
class constant_int;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
|
||||
class optimize_trans {
|
||||
private:
|
||||
ir::value *replace_phi(ir::value* value, ir::builder &builder);
|
||||
ir::value *replace_phi(ir::value* value, ir::builder &builder, const std::vector<ir::constant_int *> &perm);
|
||||
|
||||
public:
|
||||
optimize_trans() {}
|
||||
|
@@ -132,7 +132,7 @@ public:
|
||||
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_atomic_add(value *ptr, value *val, const std::string &name = "");
|
||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||
value *create_trans(value *A, const std::string &name = "");
|
||||
value *create_trans(value *A, const std::vector<constant_int *> &perm = {}, const std::string &name = "");
|
||||
value *create_sqrt(value *A, const std::string &name = "");
|
||||
value *create_reduce(value *A, unsigned axis, const std::string &name = "");
|
||||
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
|
||||
|
@@ -585,13 +585,18 @@ private:
|
||||
class trans_inst: public builtin_inst {
|
||||
public:
|
||||
ir::type* get_res_ty(ir::type* in);
|
||||
std::vector<constant_int*> get_default_perm(ir::type* ty);
|
||||
|
||||
private:
|
||||
trans_inst(value *arg, const std::string& name, instruction* next);
|
||||
trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string& name, instruction* next);
|
||||
std::string repr_impl() const { return "trans"; }
|
||||
|
||||
public:
|
||||
static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
|
||||
static instruction* create(value *arg, const std::vector<constant_int*>& perm = {}, const std::string &name = "", instruction *next = nullptr);
|
||||
const std::vector<constant_int*> get_perm() const;
|
||||
|
||||
private:
|
||||
std::vector<constant_int*> perm_;
|
||||
};
|
||||
|
||||
class sqrt_inst: public builtin_inst {
|
||||
|
@@ -68,7 +68,13 @@ void optimize_dot::run(ir::module &mod) {
|
||||
}
|
||||
// dot(op(a), b)
|
||||
if(!trans_b){
|
||||
ir::value* BB = builder.create_trans(B);
|
||||
size_t size = B->get_type()->get_tile_shapes().size();
|
||||
std::vector<ir::constant_int*> perm(size);
|
||||
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
|
||||
for(size_t i = 0; i < size; i++)
|
||||
perm[i] = ir::constant_int::get(int32_ty, i);
|
||||
std::swap(perm[0], perm[1]);
|
||||
ir::value* BB = builder.create_trans(B, perm);
|
||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
||||
dot->replace_all_uses_with(NT);
|
||||
to_delete.push_back(dot);
|
||||
|
@@ -7,12 +7,13 @@ namespace codegen{
|
||||
|
||||
|
||||
ir::value* optimize_trans::replace_phi(ir::value* value,
|
||||
ir::builder& builder){
|
||||
ir::builder& builder,
|
||||
const std::vector<ir::constant_int*> &perm){
|
||||
if(auto phi = dynamic_cast<ir::phi_node*>(value)) {
|
||||
// transpose operands
|
||||
std::vector<ir::value*> incs;
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
|
||||
incs.push_back(replace_phi(phi->get_incoming_value(n), builder));
|
||||
incs.push_back(replace_phi(phi->get_incoming_value(n), builder, perm));
|
||||
// create phi for transposed values
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name());
|
||||
@@ -26,7 +27,7 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
|
||||
auto it = std::find(block->begin(), block->end(), i);
|
||||
it++;
|
||||
builder.set_insert_point(it);
|
||||
ir::instruction *trans = (ir::instruction*)builder.create_trans(i);
|
||||
ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
|
||||
i->replace_all_uses_with(trans);
|
||||
trans->set_operand(0, i);
|
||||
return trans;
|
||||
@@ -53,7 +54,7 @@ void optimize_trans::run(ir::module &mod) {
|
||||
|
||||
// trans(phi) -> phi(trans(), trans()...)
|
||||
if(dynamic_cast<ir::phi_node*>(op)){
|
||||
ir::value* new_phi = replace_phi(op, builder);
|
||||
ir::value* new_phi = replace_phi(op, builder, trans->get_perm());
|
||||
trans->replace_all_uses_with(new_phi);
|
||||
}
|
||||
}
|
||||
|
@@ -974,11 +974,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
});
|
||||
}
|
||||
// trans
|
||||
else if(dynamic_cast<ir::trans_inst*>(ins)) {
|
||||
else if(auto* x = dynamic_cast<ir::trans_inst*>(ins)) {
|
||||
distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0));
|
||||
auto perm = x->get_perm();
|
||||
in->for_each([&](indices_t idx){
|
||||
indices_t out_idx = idx;
|
||||
std::rotate(out_idx.begin(), out_idx.begin() + 1, out_idx.end());
|
||||
indices_t out_idx(idx.size());
|
||||
for(size_t i = 0; i < idx.size(); i++)
|
||||
out_idx[i] = idx[perm[i]->get_value()];
|
||||
ti->set_value(out_idx, in->get_value(idx));
|
||||
});
|
||||
}
|
||||
|
@@ -98,11 +98,11 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
|
||||
}
|
||||
// Trans
|
||||
else if(dynamic_cast<ir::trans_inst*>(v)){
|
||||
else if(auto *x = dynamic_cast<ir::trans_inst*>(v)){
|
||||
ir::value *op = v->get_operand(0);
|
||||
size_t n_shapes = shapes.size();
|
||||
for(unsigned i = 0; i < n_shapes; i++)
|
||||
add_constraint({v, (i + 1) % n_shapes}, {op, i});
|
||||
auto perm = x->get_perm();
|
||||
for(unsigned i = 0; i < perm.size(); i++)
|
||||
add_constraint({v, perm[i]->get_value()}, {op, i});
|
||||
}
|
||||
// Broadcast
|
||||
else if(dynamic_cast<ir::broadcast_inst*>(v)){
|
||||
|
@@ -75,7 +75,7 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4";
|
||||
std::string XBS0 = "TN", XBS1 = "TK/4", XBS2 = "4";
|
||||
std::string XBS0 = "TK/4", XBS1 = "TN", XBS2 = "4";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
@@ -84,11 +84,13 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string useb = BT_ ? "trans(xb)" : "xb";
|
||||
if(AT_){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(XAS0, XAS1);
|
||||
std::swap(bca0, bca1);
|
||||
std::swap(lda0, lda1);
|
||||
}
|
||||
if(BT_){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(XBS0, XBS1);
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
}
|
||||
@@ -144,6 +146,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
}
|
||||
)";
|
||||
|
||||
std::cout << res << std::endl;
|
||||
os << res;
|
||||
}
|
||||
|
||||
|
@@ -324,8 +324,8 @@ value *builder::create_dot(value *A, value *B, value *C, const std::string &name
|
||||
return insert(dot_inst::create_nn(A, B, C, name));
|
||||
}
|
||||
|
||||
value *builder::create_trans(value *A, const std::string &name) {
|
||||
return insert(trans_inst::create(A, name));
|
||||
value *builder::create_trans(value *A, const std::vector<ir::constant_int*>& perm, const std::string &name) {
|
||||
return insert(trans_inst::create(A, perm, name));
|
||||
}
|
||||
|
||||
value *builder::create_sqrt(value *A, const std::string &name) {
|
||||
|
@@ -572,13 +572,31 @@ ir::type* trans_inst::get_res_ty(ir::type* ty) {
|
||||
return tile_type::get(ty->get_scalar_ty(), shapes);
|
||||
}
|
||||
|
||||
trans_inst::trans_inst(value *arg, const std::string &name, instruction *next)
|
||||
std::vector<constant_int*> trans_inst::get_default_perm(ir::type* ty) {
|
||||
auto size = ty->get_tile_shapes().size();
|
||||
ir::type* int32_ty = type::get_int32_ty(ty->get_context());
|
||||
std::vector<constant_int*> result;
|
||||
for(size_t i = 0; i < size; i++)
|
||||
result.push_back(ir::constant_int::get(int32_ty, i + 1 % size));
|
||||
return result;
|
||||
}
|
||||
|
||||
trans_inst::trans_inst(value *arg, const std::vector<constant_int*>& perm, const std::string &name, instruction *next)
|
||||
: builtin_inst(get_res_ty(arg->get_type()), 1, 1, name, next) {
|
||||
perm_ = perm;
|
||||
if(perm_.empty())
|
||||
perm_ = get_default_perm(arg->get_type());
|
||||
auto size = arg->get_type()->get_tile_shapes().size();
|
||||
assert(perm_.size() == size);
|
||||
set_operand(0, arg);
|
||||
}
|
||||
|
||||
instruction* trans_inst::create(value *arg, const std::string &name, instruction *next) {
|
||||
return new trans_inst(arg, name, next);
|
||||
instruction* trans_inst::create(value *arg, const std::vector<constant_int *> &perm, const std::string &name, instruction *next) {
|
||||
return new trans_inst(arg, perm, name, next);
|
||||
}
|
||||
|
||||
const std::vector<constant_int*> trans_inst::get_perm() const {
|
||||
return perm_;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Reference in New Issue
Block a user