[LANG] Added support for device functions (#484)
This commit is contained in:
@@ -9,17 +9,12 @@
|
||||
namespace triton{
|
||||
namespace ir{
|
||||
|
||||
/* Module */
|
||||
module::module(const std::string &name, builder &builder)
|
||||
: name_(name), builder_(builder) {
|
||||
/* */
|
||||
value_constructor::value_constructor(ir::builder& builder): builder_(builder){
|
||||
sealed_blocks_.insert(nullptr);
|
||||
}
|
||||
|
||||
ir::builder& module::get_builder() {
|
||||
return builder_;
|
||||
}
|
||||
|
||||
void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
|
||||
void value_constructor::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
|
||||
values_[val_key_t{name, block}] = value;
|
||||
auto it = metadatas_.find(name);
|
||||
if(auto *x = dynamic_cast<ir::instruction*>(value))
|
||||
@@ -29,23 +24,11 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::valu
|
||||
// value->set_name(name);
|
||||
}
|
||||
|
||||
void module::set_value(const std::string& name, ir::value *value){
|
||||
void value_constructor::set_value(const std::string& name, ir::value *value){
|
||||
return set_value(name, builder_.get_insert_block(), value);
|
||||
}
|
||||
|
||||
void module::set_const(const std::string& name){
|
||||
const_.insert(name);
|
||||
}
|
||||
|
||||
void module::set_continue_fn(std::function<ir::value*()> fn) {
|
||||
continue_fn_ = fn;
|
||||
}
|
||||
|
||||
std::function<ir::value*()> module::get_continue_fn() {
|
||||
return continue_fn_;
|
||||
}
|
||||
|
||||
ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
|
||||
ir::phi_node* value_constructor::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
|
||||
basic_block::iterator insert = block->get_first_non_phi();
|
||||
if(insert != block->end()){
|
||||
builder_.set_insert_point(insert);
|
||||
@@ -56,7 +39,7 @@ ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_bloc
|
||||
return res;
|
||||
}
|
||||
|
||||
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
|
||||
ir::value *value_constructor::try_remove_trivial_phis(ir::phi_node *&phi){
|
||||
// find non-self references
|
||||
std::set<ir::value*> non_self_ref;
|
||||
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
|
||||
@@ -69,7 +52,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
|
||||
assert(same != nullptr);
|
||||
phi->replace_all_uses_with(same);
|
||||
phi->erase_from_parent();
|
||||
std::set<ir::user*> users = phi->get_users();
|
||||
std::vector<ir::user*> users = phi->get_users();
|
||||
for(ir::user* u: users)
|
||||
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
|
||||
if(uphi != phi)
|
||||
@@ -78,7 +61,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
|
||||
}
|
||||
|
||||
|
||||
ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){
|
||||
ir::value *value_constructor::add_phi_operands(const std::string& name, ir::phi_node *&phi){
|
||||
// already initialized
|
||||
if(phi->get_num_operands())
|
||||
return phi;
|
||||
@@ -90,12 +73,11 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi)
|
||||
return phi;
|
||||
}
|
||||
|
||||
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
|
||||
ir::value *value_constructor::get_value_recursive(const std::string& name, ir::basic_block *block) {
|
||||
ir::value *result;
|
||||
bool is_const = const_.find(name) != const_.end();
|
||||
auto &preds = block->get_predecessors();
|
||||
auto preds = block->get_predecessors();
|
||||
ir::type *ty = types_.at(name);
|
||||
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
||||
if(block && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
||||
incomplete_phis_[block][name] = make_phi(ty, 1, block);
|
||||
result = (ir::value*)incomplete_phis_[block][name];
|
||||
}
|
||||
@@ -117,10 +99,12 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
|
||||
return result;
|
||||
}
|
||||
|
||||
ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
|
||||
ir::value *value_constructor::get_value(const std::string& name, ir::basic_block *block) {
|
||||
ir::basic_block* save_block = builder_.get_insert_block();
|
||||
ir::basic_block::iterator save_pt = builder_.get_insert_point();
|
||||
val_key_t key(name, block);
|
||||
// std::cout << values_.size() << std::endl;
|
||||
// std::cout << name << " " << block << " " << values_.begin()->first.first << " " << values_.begin()->first.second << std::endl;
|
||||
if(values_.find(key) != values_.end()){
|
||||
return values_.at(key);
|
||||
}
|
||||
@@ -131,15 +115,11 @@ ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
|
||||
return result;
|
||||
}
|
||||
|
||||
ir::value *module::get_value(const std::string& name) {
|
||||
ir::value *value_constructor::get_value(const std::string& name) {
|
||||
return get_value(name, builder_.get_insert_block());
|
||||
}
|
||||
|
||||
const std::string& module::get_name() {
|
||||
return name_;
|
||||
}
|
||||
|
||||
void module::seal_block(ir::basic_block *block){
|
||||
void value_constructor::seal_block(ir::basic_block *block){
|
||||
for(auto &x: incomplete_phis_[block]){
|
||||
add_phi_operands(x.first, x.second);
|
||||
if(get_value(x.first) == x.second)
|
||||
@@ -149,11 +129,40 @@ void module::seal_block(ir::basic_block *block){
|
||||
incomplete_phis_[block].clear();
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Module */
|
||||
|
||||
module::module(const std::string &name, builder &builder)
|
||||
: name_(name), builder_(builder) {
|
||||
}
|
||||
|
||||
void module::reset_ret_ty(const std::string& name, type* ty) {
|
||||
get_function(name)->get_fn_type()->reset_ret_ty(ty);
|
||||
}
|
||||
|
||||
ir::builder& module::get_builder() {
|
||||
return builder_;
|
||||
}
|
||||
|
||||
void module::set_continue_fn(std::function<ir::value*()> fn) {
|
||||
continue_fn_ = fn;
|
||||
}
|
||||
|
||||
std::function<ir::value*()> module::get_continue_fn() {
|
||||
return continue_fn_;
|
||||
}
|
||||
|
||||
const std::string& module::get_name() {
|
||||
return name_;
|
||||
}
|
||||
|
||||
/* functions */
|
||||
function *module::get_or_insert_function(const std::string &name, function_type *ty) {
|
||||
function *&fn = (function*&)symbols_[name];
|
||||
if(fn == nullptr)
|
||||
return fn = function::create(ty, global_value::external, name, this);
|
||||
if(fn == nullptr){
|
||||
fn = function::create(ty, global_value::external, name, this);
|
||||
}
|
||||
return fn;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user