[LANG] Added support for device functions (#484)

This commit is contained in:
Philippe Tillet
2022-04-03 20:58:16 -07:00
committed by GitHub
parent e85c7a7fc7
commit 2bed6fc850
39 changed files with 1213 additions and 379 deletions

View File

@@ -1,3 +1,5 @@
#include <iostream>
#include <algorithm>
#include "triton/ir/basic_block.h"
#include "triton/ir/instructions.h"
#include "triton/ir/type.h"
@@ -9,23 +11,68 @@ namespace ir {
class phi_node;
basic_block::basic_block(context &ctx, const std::string &name, function *parent):
basic_block::basic_block(context &ctx, const std::string &name, function *parent, basic_block* next):
value(type::get_label_ty(ctx), name), ctx_(ctx), parent_(parent) {
if(parent_)
parent_->insert_block(this);
parent_->insert_block(this, next);
}
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent){
return new basic_block(ctx, name, parent);
basic_block* basic_block::create(context &ctx, const std::string &name, function *parent, basic_block* next){
return new basic_block(ctx, name, parent, next);
}
void basic_block::add_predecessor(basic_block *pred) {
preds_.push_back(pred);
if(pred)
pred->succs_.push_back(this);
void basic_block::replace_phi_uses_with(basic_block* before, basic_block* after) {
for(ir::instruction* i: inst_list_){
auto* curr_phi = dynamic_cast<ir::phi_node*>(i);
if(!curr_phi)
break;
curr_phi->replace_uses_of_with(before, after);
}
}
void basic_block::append_instruction(ir::instruction* i){
i->set_parent(this);
inst_list_.push_back(i);
}
basic_block* basic_block::split_before(ir::instruction* loc, const std::string& name) {
basic_block* ret = basic_block::create(ctx_, name, parent_, this);
ret->set_name(get_name());
set_name("after_" + name);
// splice instruction list
auto loc_it = std::find(inst_list_.begin(), inst_list_.end(), loc);
ret->get_inst_list().splice(ret->get_inst_list().begin(), inst_list_, inst_list_.begin(), loc_it);
for(ir::instruction* i: ret->get_inst_list())
i->set_parent(ret);
// the predecessors of `this` becomes the predecessors of `ret`
for(ir::basic_block* pred: get_predecessors()){
auto* term = dynamic_cast<ir::terminator_inst*>(pred->get_inst_list().back());
assert(term);
term->replace_uses_of_with(this, ret);
replace_phi_uses_with(pred, ret);
}
ir::branch_inst* br = branch_inst::create(this);
ret->append_instruction(br);
return ret;
}
std::vector<basic_block*> basic_block::get_predecessors() const {
std::vector<basic_block*> ret;
for(ir::user* u: users_)
if(auto term = dynamic_cast<ir::terminator_inst*>(u))
ret.push_back(term->get_parent());
return ret;
}
std::vector<basic_block*> basic_block::get_successors() const {
std::vector<basic_block*> ret;
for(ir::instruction* i: inst_list_)
for(ir::value* v: i->ops())
if(auto block = dynamic_cast<ir::basic_block*>(v))
ret.push_back(block);
return ret;
}
basic_block::iterator basic_block::get_first_non_phi(){
auto it = begin();

View File

@@ -117,13 +117,10 @@ type *builder::get_double_ty()
//===----------------------------------------------------------------------===//
value* builder::create_br(basic_block *dest){
dest->add_predecessor(block_);
return insert(branch_inst::create(dest));
}
value* builder::create_cond_br(value *cond, basic_block *if_dest, basic_block *else_dest){
if_dest->add_predecessor(block_);
else_dest->add_predecessor(block_);
return insert(branch_inst::create(cond, if_dest, else_dest));
}
@@ -131,6 +128,10 @@ value *builder::create_ret_void() {
return insert(return_inst::create(ctx_));
}
value *builder::create_ret(value* val) {
return insert(return_inst::create(ctx_, val));
}
//===----------------------------------------------------------------------===//
// cast instructions
//===----------------------------------------------------------------------===//
@@ -163,6 +164,19 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved){
return insert(phi_node::create(ty, num_reserved));
}
//===----------------------------------------------------------------------===//
// call instructions
//===----------------------------------------------------------------------===//
value *builder::create_call(function* fn, const std::vector<value*>& args){
return insert(call_inst::create(fn, args));
}
value* builder::create_launch(function* fn, const std::vector<value*>& args, const std::vector<value*>& grid, value* num_warps){
return insert(launch_inst::create(fn, args, grid, num_warps));
}
//===----------------------------------------------------------------------===//
// binary float instructions
//===----------------------------------------------------------------------===//
@@ -307,6 +321,19 @@ value *builder::create_masked_store(value *ptr, value *val, value *mask){
return insert(masked_store_inst::create(ptr, val, mask));
}
//===----------------------------------------------------------------------===//
// struct instructions
//===----------------------------------------------------------------------===//
// Struct instructions
value *builder::create_insert_value(value* val, value *elt, size_t idx){
return insert(insert_value_inst::create(val, elt, idx));
}
value *builder::create_extract_value(value* val, size_t idx) {
return insert(extract_value_inst::create(val, idx));
}
//===----------------------------------------------------------------------===//
// block instructions
//===----------------------------------------------------------------------===//

View File

@@ -865,6 +865,9 @@ ir::value *dispatch::clock(ir::builder *builder) {
}
//===----------------------------------------------------------------------===//
// Control FLow
//===----------------------------------------------------------------------===//
//
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){

View File

@@ -33,8 +33,10 @@ void argument::accept(visitor *v) {
/* function */
function::function(function_type *ty, linkage_types_t linkage,
const std::string &name, module *parent)
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty) {
: global_object(ty, 0, linkage, name), parent_(parent), fn_ty_(ty), is_kernel_(false) {
unsigned num_params = fn_ty_->get_num_params();
if(parent)
parent->push_function(this);
// skip if no parameter
if(num_params == 0)
return;
@@ -44,8 +46,6 @@ function::function(function_type *ty, linkage_types_t linkage,
type *param_ty = fn_ty_->get_param_ty(i);
args_[i] = argument::create(param_ty, "", this, i);
}
if(parent)
parent->push_function(this);
}
/* basic block */

View File

@@ -5,6 +5,7 @@
#include "triton/ir/instructions.h"
#include "triton/ir/constant.h"
#include "triton/ir/type.h"
#include "triton/ir/function.h"
namespace triton{
namespace ir{
@@ -79,6 +80,70 @@ phi_node* phi_node::create(type *ty, unsigned num_reserved, const std::string &n
return new phi_node(ty, num_reserved, name, next);
}
//===----------------------------------------------------------------------===//
// call_inst classes
//===----------------------------------------------------------------------===//
std::string call_inst::repr_impl() const { return "call " + fn_->get_name(); }
call_inst::call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next)
: instruction(fn->get_fn_type()->get_return_ty(), INST_CALL, values.size(), name, next), fn_(fn){
for(size_t i = 0; i < values.size(); i++)
set_operand(i, values.at(i));
}
call_inst* call_inst::create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name, instruction *next) {
return new call_inst(fn, values, name, next);
}
// launch
launch_inst::launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps, const std::string& name, instruction* next)
: instruction(fn->get_fn_type()->get_return_ty(), INST_LAUNCH, 1 + values.size() + grid.size() + 1, name, next){
int k = 0;
if(grid.size() != 3)
throw std::runtime_error("grid must have 3 elements");
set_operand(k++, fn);
val_begin = k;
for(ir::value* v: values)
set_operand(k++, v);
val_end = k;
grid_begin = k;
for(ir::value* g: grid)
set_operand(k++, g);
grid_end = k;
set_operand(k++, num_warps);
}
ir::function* launch_inst::get_fn() {
return (ir::function*)get_operand(0);
}
std::vector<ir::value*> launch_inst::get_values() {
std::vector<ir::value*> ret;
for(int i = val_begin; i < val_end; i++)
ret.push_back(get_operand(i));
return ret;
}
std::vector<ir::value*> launch_inst::get_grid() {
std::vector<ir::value*> ret;
for(int i = grid_begin; i < grid_end; i++)
ret.push_back(get_operand(i));
return ret;
}
ir::value* launch_inst::get_num_warps() {
return get_operand(grid_end);
}
launch_inst* launch_inst::create(ir::function *fn, const std::vector<ir::value *> &values, const std::vector<ir::value *> &grid, ir::value *num_warps, const std::string &name, instruction *next) {
return new launch_inst(fn, values, grid, num_warps, name, next);
}
//===----------------------------------------------------------------------===//
// binary_operator classes
@@ -324,7 +389,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
// return_inst
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
: terminator_inst(type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
: terminator_inst(ret_val?ret_val->get_type():type::get_void_ty(ctx), INST_RETURN, ret_val!=nullptr, "", next){
if(ret_val)
set_operand(0, ret_val);
}
@@ -521,6 +586,36 @@ masked_store_inst::masked_store_inst(value *ptr, value *val, value *mask,
masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask, const std::string &name, instruction *next) {
return new masked_store_inst(ptr, val, mask, name, next);
}
//===----------------------------------------------------------------------===//
// struct classes
//===----------------------------------------------------------------------===//
// insert value
insert_value_inst::insert_value_inst(value *val, value *elt, size_t idx, const std::string& name, instruction *next)
: instruction(val->get_type(), INST_INSERT_VALUE, 2, name, next), idx_(idx) {
set_operand(0, val);
set_operand(1, elt);
}
insert_value_inst* insert_value_inst::create(value *val, value *elt, size_t idx, const std::string& name, instruction *next){
return new insert_value_inst(val, elt, idx, name, next);
}
// extract value
extract_value_inst::extract_value_inst(value *val, size_t idx, const std::string& name, instruction *next)
: instruction(val->get_type()->get_struct_type(idx), INST_EXTRACT_VALUE, 1, name, next), idx_(idx) {
set_operand(0, val);
}
extract_value_inst* extract_value_inst::create(value *val, size_t idx, const std::string& name, instruction *next){
return new extract_value_inst(val, idx, name, next);
}
//===----------------------------------------------------------------------===//
// retile_inst classes
//===----------------------------------------------------------------------===//
@@ -575,6 +670,9 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
return new downcast_inst(arg->get_type()->get_scalar_ty(), INST_DOWNCAST, arg, name, next);
}
//===----------------------------------------------------------------------===//
// matmul_inst classes
//===----------------------------------------------------------------------===//

View File

@@ -9,17 +9,12 @@
namespace triton{
namespace ir{
/* Module */
module::module(const std::string &name, builder &builder)
: name_(name), builder_(builder) {
/* */
value_constructor::value_constructor(ir::builder& builder): builder_(builder){
sealed_blocks_.insert(nullptr);
}
ir::builder& module::get_builder() {
return builder_;
}
void module::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
void value_constructor::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
values_[val_key_t{name, block}] = value;
auto it = metadatas_.find(name);
if(auto *x = dynamic_cast<ir::instruction*>(value))
@@ -29,23 +24,11 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::valu
// value->set_name(name);
}
void module::set_value(const std::string& name, ir::value *value){
void value_constructor::set_value(const std::string& name, ir::value *value){
return set_value(name, builder_.get_insert_block(), value);
}
void module::set_const(const std::string& name){
const_.insert(name);
}
void module::set_continue_fn(std::function<ir::value*()> fn) {
continue_fn_ = fn;
}
std::function<ir::value*()> module::get_continue_fn() {
return continue_fn_;
}
ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
ir::phi_node* value_constructor::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
basic_block::iterator insert = block->get_first_non_phi();
if(insert != block->end()){
builder_.set_insert_point(insert);
@@ -56,7 +39,7 @@ ir::phi_node* module::make_phi(ir::type *ty, unsigned num_values, ir::basic_bloc
return res;
}
ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
ir::value *value_constructor::try_remove_trivial_phis(ir::phi_node *&phi){
// find non-self references
std::set<ir::value*> non_self_ref;
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
@@ -69,7 +52,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
assert(same != nullptr);
phi->replace_all_uses_with(same);
phi->erase_from_parent();
std::set<ir::user*> users = phi->get_users();
std::vector<ir::user*> users = phi->get_users();
for(ir::user* u: users)
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi)
@@ -78,7 +61,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
}
ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi){
ir::value *value_constructor::add_phi_operands(const std::string& name, ir::phi_node *&phi){
// already initialized
if(phi->get_num_operands())
return phi;
@@ -90,12 +73,11 @@ ir::value *module::add_phi_operands(const std::string& name, ir::phi_node *&phi)
return phi;
}
ir::value *module::get_value_recursive(const std::string& name, ir::basic_block *block) {
ir::value *value_constructor::get_value_recursive(const std::string& name, ir::basic_block *block) {
ir::value *result;
bool is_const = const_.find(name) != const_.end();
auto &preds = block->get_predecessors();
auto preds = block->get_predecessors();
ir::type *ty = types_.at(name);
if(block && !is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
if(block && sealed_blocks_.find(block) == sealed_blocks_.end()){
incomplete_phis_[block][name] = make_phi(ty, 1, block);
result = (ir::value*)incomplete_phis_[block][name];
}
@@ -117,10 +99,12 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
return result;
}
ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
ir::value *value_constructor::get_value(const std::string& name, ir::basic_block *block) {
ir::basic_block* save_block = builder_.get_insert_block();
ir::basic_block::iterator save_pt = builder_.get_insert_point();
val_key_t key(name, block);
// std::cout << values_.size() << std::endl;
// std::cout << name << " " << block << " " << values_.begin()->first.first << " " << values_.begin()->first.second << std::endl;
if(values_.find(key) != values_.end()){
return values_.at(key);
}
@@ -131,15 +115,11 @@ ir::value *module::get_value(const std::string& name, ir::basic_block *block) {
return result;
}
ir::value *module::get_value(const std::string& name) {
ir::value *value_constructor::get_value(const std::string& name) {
return get_value(name, builder_.get_insert_block());
}
const std::string& module::get_name() {
return name_;
}
void module::seal_block(ir::basic_block *block){
void value_constructor::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block]){
add_phi_operands(x.first, x.second);
if(get_value(x.first) == x.second)
@@ -149,11 +129,40 @@ void module::seal_block(ir::basic_block *block){
incomplete_phis_[block].clear();
}
/* Module */
module::module(const std::string &name, builder &builder)
: name_(name), builder_(builder) {
}
void module::reset_ret_ty(const std::string& name, type* ty) {
get_function(name)->get_fn_type()->reset_ret_ty(ty);
}
ir::builder& module::get_builder() {
return builder_;
}
void module::set_continue_fn(std::function<ir::value*()> fn) {
continue_fn_ = fn;
}
std::function<ir::value*()> module::get_continue_fn() {
return continue_fn_;
}
const std::string& module::get_name() {
return name_;
}
/* functions */
function *module::get_or_insert_function(const std::string &name, function_type *ty) {
function *&fn = (function*&)symbols_[name];
if(fn == nullptr)
return fn = function::create(ty, global_value::external, name, this);
if(fn == nullptr){
fn = function::create(ty, global_value::external, name, this);
}
return fn;
}

View File

@@ -188,7 +188,26 @@ bool composite_type::index_valid(value *idx) const{
}
//===----------------------------------------------------------------------===//
// tile_type class
// struct_type class
//===----------------------------------------------------------------------===//
struct_type::struct_type(const contained_tys_vec_t& tys, bool is_packed)
: composite_type(tys[0]->get_context(), StructTyID), is_packed_(is_packed) {
contained_tys_ = tys;
}
struct_type* struct_type::get(const contained_tys_vec_t& tys, bool is_packed) {
assert(tys.size());
context_impl* impl = tys[0]->get_context().p_impl.get();
struct_type *& entry = impl->struct_tys[tys];
if(!entry)
entry = new struct_type(tys, is_packed);
return entry;
}
//===----------------------------------------------------------------------===//
// block_type class
//===----------------------------------------------------------------------===//
block_type::block_type(type *ty, const block_shapes_t &shapes)

View File

@@ -1,5 +1,6 @@
#include <cassert>
#include <iostream>
#include <algorithm>
#include "triton/ir/value.h"
#include "triton/ir/instructions.h"
@@ -17,11 +18,11 @@ value::value(type *ty, const std::string &name): ty_(ty){
}
void value::add_use(user *arg) {
users_.insert(arg);
users_.push_back(arg);
}
value::users_t::iterator value::erase_use(user *arg){
auto it = users_.find(arg);
auto it = std::find(users_.begin(), users_.end(), arg);
if(it == users_.end())
return it;
return users_.erase(it);