diff --git a/include/triton/ir/value.h b/include/triton/ir/value.h index e192a54ef..e1599d6bc 100644 --- a/include/triton/ir/value.h +++ b/include/triton/ir/value.h @@ -20,15 +20,18 @@ class visitor; //===----------------------------------------------------------------------===// class value { +public: + typedef std::set users_t; + public: // constructor value(type *ty, const std::string &name = ""); virtual ~value(){ } // uses void add_use(user* arg); - unsigned erase_use(user* arg); + users_t::iterator erase_use(user* arg); const std::set &get_users() { return users_; } - virtual void replace_all_uses_with(value *target); + void replace_all_uses_with(value *target); // name void set_name(const std::string &name); const std::string &get_name() const { return name_; } @@ -41,7 +44,7 @@ private: protected: type *ty_; - std::set users_; + users_t users_; }; //===----------------------------------------------------------------------===// @@ -63,6 +66,7 @@ public: user(type *ty, unsigned num_ops, const std::string &name = "") : value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){ } + virtual ~user() { } // Operands const ops_t& ops() { return ops_; } @@ -74,8 +78,7 @@ public: unsigned get_num_hidden() const; // Utils - void replace_all_uses_with(value *target); - void replace_uses_of_with(value *before, value *after); + value::users_t::iterator replace_uses_of_with(value *before, value *after); private: diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 67617478b..28edb4e4f 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -71,9 +71,9 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ // unique value or self-reference ir::value *same = *non_self_ref.begin(); assert(same != nullptr); - std::set users = phi->get_users(); phi->replace_all_uses_with(same); phi->erase_from_parent(); + std::set users = phi->get_users(); for(ir::user* u: users) if(auto *uphi = dynamic_cast(u)) if(uphi != phi) diff --git a/lib/ir/value.cc b/lib/ir/value.cc index a43aaa05e..81f803df6 100644 --- a/lib/ir/value.cc +++ b/lib/ir/value.cc @@ -1,4 +1,5 @@ #include +#include #include "triton/ir/value.h" #include "triton/ir/instructions.h" @@ -19,8 +20,11 @@ void value::add_use(user *arg) { users_.insert(arg); } -unsigned value::erase_use(user *arg){ - return users_.erase(arg); +value::users_t::iterator value::erase_use(user *arg){ + auto it = users_.find(arg); + if(it == users_.end()) + return it; + return users_.erase(it); } // TODO: automatic naming scheme + update symbol table @@ -29,9 +33,12 @@ void value::set_name(const std::string &name){ } void value::replace_all_uses_with(value *target){ - throw std::runtime_error("not implemented"); + for (auto it = users_.begin(); it != users_.end(); ) { + it = (*it)->replace_uses_of_with(this, target); + } } + void visitor::visit_value(ir::value* v) { v->accept(this); } @@ -59,18 +66,12 @@ unsigned user::get_num_hidden() const { return num_hidden_; } -void user::replace_all_uses_with(value *target) { - for(auto it = users_.begin(); it != users_.end(); it++){ - (*it)->replace_uses_of_with(this, target); - } -} - -void user::replace_uses_of_with(value *before, value *after) { +value::users_t::iterator user::replace_uses_of_with(value *before, value *after) { for(size_t i = 0; i < ops_.size(); i++) if(ops_[i] == before) ops_[i] = after; after->add_use(this); - before->erase_use(this); + return before->erase_use(this); } diff --git a/python/setup.py b/python/setup.py index d867dc122..cb6b1bcd6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -100,7 +100,7 @@ for d in directories: setup( name='triton', - version='0.1', + version='0.1.1', author='Philippe Tillet', author_email='ptillet@g.harvard.edu', description='A language and compiler for custom Deep Learning operations',