[LANG] Fixed undefined behavior in replace_all_uses_with()

This commit is contained in:
Philippe Tillet
2020-05-11 12:15:56 -04:00
committed by Philippe Tillet
parent ddd89e1b22
commit 13ff6472e0
4 changed files with 22 additions and 18 deletions

View File

@@ -20,15 +20,18 @@ class visitor;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class value { class value {
public:
typedef std::set<user*> users_t;
public: public:
// constructor // constructor
value(type *ty, const std::string &name = ""); value(type *ty, const std::string &name = "");
virtual ~value(){ } virtual ~value(){ }
// uses // uses
void add_use(user* arg); void add_use(user* arg);
unsigned erase_use(user* arg); users_t::iterator erase_use(user* arg);
const std::set<user*> &get_users() { return users_; } const std::set<user*> &get_users() { return users_; }
virtual void replace_all_uses_with(value *target); void replace_all_uses_with(value *target);
// name // name
void set_name(const std::string &name); void set_name(const std::string &name);
const std::string &get_name() const { return name_; } const std::string &get_name() const { return name_; }
@@ -41,7 +44,7 @@ private:
protected: protected:
type *ty_; type *ty_;
std::set<user*> users_; users_t users_;
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -63,6 +66,7 @@ public:
user(type *ty, unsigned num_ops, const std::string &name = "") user(type *ty, unsigned num_ops, const std::string &name = "")
: value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){ : value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){
} }
virtual ~user() { }
// Operands // Operands
const ops_t& ops() { return ops_; } const ops_t& ops() { return ops_; }
@@ -74,8 +78,7 @@ public:
unsigned get_num_hidden() const; unsigned get_num_hidden() const;
// Utils // Utils
void replace_all_uses_with(value *target); value::users_t::iterator replace_uses_of_with(value *before, value *after);
void replace_uses_of_with(value *before, value *after);
private: private:

View File

@@ -71,9 +71,9 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
// unique value or self-reference // unique value or self-reference
ir::value *same = *non_self_ref.begin(); ir::value *same = *non_self_ref.begin();
assert(same != nullptr); assert(same != nullptr);
std::set<ir::user*> users = phi->get_users();
phi->replace_all_uses_with(same); phi->replace_all_uses_with(same);
phi->erase_from_parent(); phi->erase_from_parent();
std::set<ir::user*> users = phi->get_users();
for(ir::user* u: users) for(ir::user* u: users)
if(auto *uphi = dynamic_cast<ir::phi_node*>(u)) if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi) if(uphi != phi)

View File

@@ -1,4 +1,5 @@
#include <cassert> #include <cassert>
#include <iostream>
#include "triton/ir/value.h" #include "triton/ir/value.h"
#include "triton/ir/instructions.h" #include "triton/ir/instructions.h"
@@ -19,8 +20,11 @@ void value::add_use(user *arg) {
users_.insert(arg); users_.insert(arg);
} }
unsigned value::erase_use(user *arg){ value::users_t::iterator value::erase_use(user *arg){
return users_.erase(arg); auto it = users_.find(arg);
if(it == users_.end())
return it;
return users_.erase(it);
} }
// TODO: automatic naming scheme + update symbol table // TODO: automatic naming scheme + update symbol table
@@ -29,9 +33,12 @@ void value::set_name(const std::string &name){
} }
void value::replace_all_uses_with(value *target){ void value::replace_all_uses_with(value *target){
throw std::runtime_error("not implemented"); for (auto it = users_.begin(); it != users_.end(); ) {
it = (*it)->replace_uses_of_with(this, target);
}
} }
void visitor::visit_value(ir::value* v) { void visitor::visit_value(ir::value* v) {
v->accept(this); v->accept(this);
} }
@@ -59,18 +66,12 @@ unsigned user::get_num_hidden() const {
return num_hidden_; return num_hidden_;
} }
void user::replace_all_uses_with(value *target) { value::users_t::iterator user::replace_uses_of_with(value *before, value *after) {
for(auto it = users_.begin(); it != users_.end(); it++){
(*it)->replace_uses_of_with(this, target);
}
}
void user::replace_uses_of_with(value *before, value *after) {
for(size_t i = 0; i < ops_.size(); i++) for(size_t i = 0; i < ops_.size(); i++)
if(ops_[i] == before) if(ops_[i] == before)
ops_[i] = after; ops_[i] = after;
after->add_use(this); after->add_use(this);
before->erase_use(this); return before->erase_use(this);
} }

View File

@@ -100,7 +100,7 @@ for d in directories:
setup( setup(
name='triton', name='triton',
version='0.1', version='0.1.1',
author='Philippe Tillet', author='Philippe Tillet',
author_email='ptillet@g.harvard.edu', author_email='ptillet@g.harvard.edu',
description='A language and compiler for custom Deep Learning operations', description='A language and compiler for custom Deep Learning operations',